Commit 0c34021a authored by Zhen Wang

Faster gemm_int8: the maximum speedup over the float GEMM is about 2x (int8 vs. float). Also adds gemm_with_bias and gemm_with_relu_bias.
Parent 3c642e35
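A minimal usage sketch of the new entry point (the matmul_int8 declaration added to math/math_function.h appears further down in this diff); the include path, the wrapper name, and the fully qualified Tensor type are illustrative assumptions, not part of the commit:

#include "operators/math/math_function.h"

// Hypothetical helper mirroring the calls made in the int8 GEMM tests below.
void run_int8_gemm(const paddle_mobile::framework::Tensor &a_int8,
                   const paddle_mobile::framework::Tensor &b_int8,
                   paddle_mobile::framework::Tensor *out,
                   int32_t *bias, float scale) {
  // With a non-null bias, alpha carries the requantization scale and the
  // output tensor holds int8 data; with bias == nullptr the output stays int32.
  paddle_mobile::operators::math::matmul_int8(
      a_int8, false, b_int8, false, scale, out,
      static_cast<float>(0), /*relu=*/true, bias);
}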
...@@ -107,9 +107,15 @@ inline void GemmConv(const ConvParam<CPU> &param) {
      Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
      Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-      math::matmul<Itype>(filter_slice, false, col_matrix, false,
-                          static_cast<float>(1), &out_slice,
-                          static_cast<float>(0));
+      if (param.Input()->type() == typeid(int8_t)) {
+        math::matmul_int8(filter_slice, false, col_matrix, false,
+                          static_cast<float>(1), &out_slice,
+                          static_cast<float>(0));
+      } else {
+        math::matmul<float>(filter_slice, false, col_matrix, false,
+                            static_cast<float>(1), &out_slice,
+                            static_cast<float>(0));
+      }
    }
  }
}
......
...@@ -73,8 +73,8 @@ void MulCompute(const MulParam<CPU> &param) {
  }
  if (param.InputX()->type() == typeid(int8_t)) {
    out->mutable_data<int32_t>();
-    math::matmul<int8_t>(x_matrix, false, y_matrix, false,
-                         static_cast<int8_t>(1), out, static_cast<int8_t>(0));
+    math::matmul_int8(x_matrix, false, y_matrix, false, static_cast<float>(1),
+                      out, static_cast<float>(0));
  } else {
    out->mutable_data<float>();
......
...@@ -23,10 +23,12 @@ limitations under the License. */
#if __aarch64__
#define MR_INT8 4
+#define NR_INT8 2
#define MR 6
#define NR 16
#else
#define MR_INT8 4
+#define NR_INT8 2
#define MR 6
#define NR 8
#endif
...@@ -193,52 +195,58 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
  // 8 bits int small block inner product
  void AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
                 int32_t ldc);
+  void AddDot4x2(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
+                 int32_t ldc);
  void AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
                 int32_t ldc);
  // 8 bits int inner product
-  void InnerKernelWithBias(int32_t mc, int32_t nc, int8_t alpha,
-                           const int8_t *a, const int8_t *b, int8_t beta,
-                           int32_t *c, int32_t *C, int32_t ldc, bool relu,
-                           int8_t *bias);
+  void InnerKernel(int32_t mc, int32_t nc, float alpha, const int8_t *a,
+                   const int8_t *b, float beta, int32_t *c, int32_t *C,
+                   int32_t ldc, bool relu);
+  void InnerKernelWithBias(int32_t mc, int32_t nc, float alpha, const int8_t *a,
+                           const int8_t *b, float beta, int32_t *c, int8_t *C,
+                           int32_t ldc, bool relu, int32_t *bias);
  // 8 bits int pack function
  void PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
                      int32_t lda, int8_t *buffer);
+  void PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
+                         int32_t lda, int8_t *buffer);
  void PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
                      int32_t lda, int8_t *buffer);
+  void PackMatrixB_2c_16(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
+                         int32_t ldb, int8_t *buffer);
  void PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
                      int32_t ldb, int8_t *buffer);
  void PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
                          int32_t lda, int8_t *buffer);
  void PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
                          int32_t ldb, int8_t *buffer);
+  void PackMatrixA_omp_4r_16(int32_t m, int32_t k, int32_t m_tail,
+                             const int8_t *A, int32_t lda, int8_t *buffer);
+  void PackMatrixB_omp_2c_16(int32_t k, int32_t n, int32_t n_tail,
+                             const int8_t *B, int32_t ldb, int8_t *buffer);
  // 8 bits int matrix product
-  void Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A,
-             int32_t lda, const int8_t *B, int32_t ldb, int8_t beta, int32_t *C,
-             int32_t ldc, bool relu, int8_t *bias);
-  void Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A,
-                 int32_t lda, const int8_t *B, int32_t ldb, int8_t beta,
-                 int32_t *C, int32_t ldc, bool relu, int8_t *bias);
+  void Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
+             int32_t lda, const int8_t *B, int32_t ldb, float beta, int32_t *C,
+             int32_t ldc, bool relu, int32_t *bias);
+  void Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
+             int32_t lda, const int8_t *B, int32_t ldb, float beta, int8_t *C,
+             int32_t ldc, bool relu, int32_t *bias);
+  void Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
+                 int32_t lda, const int8_t *B, int32_t ldb, float beta,
+                 int32_t *C, int32_t ldc, bool relu, int32_t *bias);
  // 8 bits int write back
-  // C = alpha * A * B + beta * C
-  void WriteWithAlphaBeta(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
-                          int32_t ldc);
  // C = A * B
  void WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, int32_t ldc);
-  // C = A * B + C
-  void WriteWithAdd(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
-                    int32_t ldc);
-  // C = A * B + bias
-  void WriteWithAddV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
-                      int32_t ldc, int8_t *bias);
-  // C = A * B + C, relu(C)
-  void WriteWithAddRelu(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
-                        int32_t ldc);
-  // C = A * B + bias, relu(C)
-  void WriteWithAddReluV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
-                          int32_t ldc, int8_t *bias);
+  // C = A * B + bias, scale * relu(C)
+  void WriteWithAddReluScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C,
+                             int32_t ldc, int32_t *bias, float scale);
+  // C = A * B + bias, scale * C
+  void WriteWithAddScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C,
                         int32_t ldc, int32_t *bias, float scale);
 private:
  int MC = 0;
...@@ -254,7 +262,7 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
  // 8 bits int
  int8_t *packedA_int8;
  int8_t *packedB_int8;
-  int32_t *packedC_int8;
+  int32_t *packedC_int32;
  int8_t *zero_int8;
};
......
This diff is collapsed.
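The collapsed file holds the new int8 kernels themselves (packing, the 4x2 micro-kernel, and the write-back routines declared in gemm.h above). A rough scalar model of the declared "C = A * B + bias, scale * relu(C)" write-back, assuming a row-major temporary c with stride nc and the truncate-toward-zero, saturate-to-±127 behaviour encoded by qadd_int32/qscale_int32 in the test further down; this is an illustration, not the committed NEON code:

#include <algorithm>
#include <cmath>
#include <cstdint>

void WriteWithAddReluScale_ref(int32_t mc, int32_t nc, const int32_t *c,
                               int8_t *C, int32_t ldc, const int32_t *bias,
                               float scale) {
  for (int32_t i = 0; i < mc; ++i) {
    for (int32_t j = 0; j < nc; ++j) {
      int64_t acc = static_cast<int64_t>(c[i * nc + j]) + bias[i];  // per-row bias
      if (acc < 0) acc = 0;                                         // relu
      float r = std::trunc(static_cast<float>(acc) * scale);        // requantize, round toward zero
      r = std::min(127.0f, std::max(-127.0f, r));                   // saturate to int8 range
      C[i * ldc + j] = static_cast<int8_t>(r);                      // write back as int8
    }
  }
}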
...@@ -28,10 +28,10 @@ namespace operators {
namespace math {
// 8 bits int matrix product (m*k x k*n)
-void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha,
+void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
                     const int8_t *A, int32_t lda, const int8_t *B, int32_t ldb,
-                     int8_t beta, int32_t *C, int32_t ldc, bool relu,
-                     int8_t *bias) {
+                     float beta, int32_t *C, int32_t ldc, bool relu,
+                     int32_t *bias) {
#ifdef _OPENMP
  int32_t max_threads = omp_get_max_threads();
#else
...@@ -39,10 +39,11 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha,
#endif
  int32_t L1 = 64 / max_threads * 1024;
-  KC = k;
+  const int32_t k_complete = (k + 15) - ((k + 15) & 15);
+  KC = k_complete;
  zero_int8 =
-      static_cast<int8_t *>(paddle_mobile::memory::Alloc(sizeof(int8_t) * KC));
-  memset(static_cast<void *>(zero_int8), 0, sizeof(int8_t) * KC);
+      static_cast<int8_t *>(paddle_mobile::memory::Alloc(sizeof(int8_t) * k));
+  memset(static_cast<void *>(zero_int8), 0, sizeof(int8_t) * k);
  if (m > n) {
    // Tile A into blocks
    MC = L1 / (KC * sizeof(int8_t));
...@@ -54,14 +55,14 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha,
    MC = (MC + MR_INT8 - 1) / MR_INT8 * MR_INT8;
  }
  // Pad B
-  NC = (n + NR - 1) / NR * NR;
+  NC = (n + NR_INT8 - 1) / NR_INT8 * NR_INT8;
  packedB_int8 = static_cast<int8_t *>(
      paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC));
#if __aarch64__
  // TODO(wzzju)
#else
-  PackMatrixB_omp_8c(KC, n, n % NR, B, ldb, packedB_int8);
+  PackMatrixB_omp_2c_16(k, n, n % NR_INT8, B, ldb, packedB_int8);
#endif
  packedA_int8 = static_cast<int8_t *>(
      paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC * max_threads));
...@@ -69,11 +70,11 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha,
    // Tile B into blocks
    NC = L1 / (KC * sizeof(int8_t));
    if (NC == 0) {
-      NC = NR;
+      NC = NR_INT8;
    } else {
      int32_t nblock_num = (n + NC - 1) / NC;
      NC = (n + nblock_num - 1) / nblock_num;
-      NC = (NC + NR - 1) / NR * NR;
+      NC = (NC + NR_INT8 - 1) / NR_INT8 * NR_INT8;
    }
    // Pad A
    MC = (m + MR_INT8 - 1) / MR_INT8 * MR_INT8;
...@@ -83,12 +84,12 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha,
#if __aarch64__
  // TODO(wzzju)
#else
-  PackMatrixA_omp_4r(m, KC, m % MR_INT8, A, lda, packedA_int8);
+  PackMatrixA_omp_4r_16(m, k, m % MR_INT8, A, lda, packedA_int8);
#endif
  packedB_int8 = static_cast<int8_t *>(
      paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC * max_threads));
  }
-  packedC_int8 = static_cast<int32_t *>(
+  packedC_int32 = static_cast<int32_t *>(
      paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC * max_threads));
  if (m > n) {
...@@ -103,14 +104,19 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha,
      int32_t mc;
      mc = s_min(m - i, MC);
      int8_t *local_A = packedA_int8 + MC * KC * local_threads;
-      int32_t *local_C = packedC_int8 + MC * NC * local_threads;
+      int32_t *local_C = packedC_int32 + MC * NC * local_threads;
#if __aarch64__
      // TODO(wzzju)
#else
-      PackMatrixA_4r(mc, KC, mc % MR_INT8, &A(i, 0), lda, local_A);
+      PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, local_A);
#endif
-      InnerKernelWithBias(mc, n, alpha, local_A, packedB_int8, beta, local_C,
-                          &C(i, 0), ldc, relu, bias + i);
+      //      InnerKernelWithBias(mc, n, alpha, local_A, packedB_int8, beta,
+      //      local_C,
+      //                          &C(i, 0), ldc, relu, bias + i);
+      if (bias == nullptr) {
+        InnerKernel(mc, n, alpha, local_A, packedB_int8, beta, local_C,
+                    &C(i, 0), ldc, relu);
+      }
    }
  } else {
#pragma omp parallel for
...@@ -123,20 +129,25 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha,
      int32_t nc;
      nc = s_min(n - j, NC);
      int8_t *local_B = packedB_int8 + KC * NC * local_threads;
-      int32_t *local_C = packedC_int8 + MC * NC * local_threads;
+      int32_t *local_C = packedC_int32 + MC * NC * local_threads;
#if __aarch64__
      // TODO(wzzju)
#else
-      PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, local_B);
+      PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, local_B);
#endif
-      InnerKernelWithBias(m, nc, alpha, packedA_int8, local_B, beta, local_C,
-                          &C(0, j), ldc, relu, bias);
+      //      InnerKernelWithBias(m, nc, alpha, packedA_int8, local_B, beta,
+      //      local_C,
+      //                          &C(0, j), ldc, relu, bias);
+      if (bias == nullptr) {
+        InnerKernel(m, nc, alpha, packedA_int8, local_B, beta, local_C,
+                    &C(0, j), ldc, relu);
+      }
    }
  }
  paddle_mobile::memory::Free(packedA_int8);
  paddle_mobile::memory::Free(packedB_int8);
-  paddle_mobile::memory::Free(packedC_int8);
+  paddle_mobile::memory::Free(packedC_int32);
  paddle_mobile::memory::Free(zero_int8);
}
...@@ -144,7 +155,7 @@ void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail,
                              const int8_t *B, int32_t ldb, int8_t *buffer) {
  const int32_t j_length = n - n_tail;
#pragma omp parallel for
-  for (int32_t j = 0; j < j_length; j += NR) {
+  for (int32_t j = 0; j < j_length; j += 8) {
    int8_t *local_buffer = buffer + j * k;
    for (int32_t i = 0; i < k; ++i) {
      const int8_t *b0 = &B(i, j);
...@@ -179,7 +190,7 @@ void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail,
    for (int32_t j = j_length; j < n; ++j) {
      *local_buffer++ = *b0++;
    }
-    for (int32_t j = n; j < j_length + NR; ++j) {
+    for (int32_t j = n; j < j_length + 8; ++j) {
      *local_buffer++ = 0;
    }
  }
...@@ -188,9 +199,9 @@ void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail,
void Gemm::PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail,
                              const int8_t *A, int32_t lda, int8_t *buffer) {
-  const int i_length = m - m_tail;
+  const int32_t i_length = m - m_tail;
#pragma omp parallel for
-  for (int32_t i = 0; i < i_length; i += MR_INT8) {
+  for (int32_t i = 0; i < i_length; i += 4) {
    const int8_t *a0 = A + i * lda;
    const int8_t *a1 = A + (i + 1) * lda;
    const int8_t *a2 = A + (i + 2) * lda;
...@@ -221,7 +232,7 @@ void Gemm::PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail,
      default:
        break;
    }
-    for (int j = 0; j < k; ++j) {
+    for (int32_t j = 0; j < k; ++j) {
      *local_buffer++ = *a0++;
      *local_buffer++ = *a1++;
      *local_buffer++ = *a2++;
...@@ -230,6 +241,232 @@ void Gemm::PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail,
  }
}
// 8 bits int PackMatrixA_4r
void Gemm::PackMatrixA_omp_4r_16(int32_t m, int32_t k, int32_t m_tail,
const int8_t *A, int32_t lda, int8_t *buffer) {
const int32_t i_length = m - m_tail;
const int32_t k_count = k >> 4;
const int32_t k_tail = k & 15;
#pragma omp parallel for
for (int32_t i = 0; i < i_length; i += 4) {
const int8_t *a0 = A + i * lda;
const int8_t *a1 = A + (i + 1) * lda;
const int8_t *a2 = A + (i + 2) * lda;
const int8_t *a3 = A + (i + 3) * lda;
int8_t *local_buffer = buffer + i * KC;
for (int32_t j = 0; j < k_count; ++j) {
#if __ARM_NEON
#if __aarch64__
// TODO(wzzju)
#else
asm volatile(
"vld1.s8 {d0, d1}, [%[a0]]! \n\t"
"vld1.s8 {d2, d3}, [%[a1]]! \n\t"
"vld1.s8 {d4, d5}, [%[a2]]! \n\t"
"vld1.s8 {d6, d7}, [%[a3]]! \n\t"
"vst1.s8 {d0, d1}, [%[local_buffer]]! \n\t"
"vst1.s8 {d2, d3}, [%[local_buffer]]! \n\t"
"vst1.s8 {d4, d5}, [%[local_buffer]]! \n\t"
"vst1.s8 {d6, d7}, [%[local_buffer]]! \n\t"
: [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1),
[a2] "+r"(a2), [a3] "+r"(a3)
:
: "memory", "q0", "q1", "q2", "q3");
#endif // __aarch64__
#else
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a0++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a1++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a2++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a3++;
}
#endif // __ARM_NEON
}
if (k_tail != 0) {
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a0++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a1++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a2++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a3++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
}
}
if (m_tail != 0) {
const int8_t *a0 = &A(i_length, 0);
const int8_t *a1 = a0 + lda;
const int8_t *a2 = a0 + 2 * lda;
const int8_t *a3 = a0 + 3 * lda;
int8_t *local_buffer = buffer + i_length * KC;
switch (m_tail) {
case 1:
a1 = zero_int8;
case 2:
a2 = zero_int8;
case 3:
a3 = zero_int8;
break;
default:
break;
}
for (int32_t j = 0; j < k_count; ++j) {
#if __ARM_NEON
#if __aarch64__
// TODO(wzzju)
#else
asm volatile(
"vld1.s8 {d0, d1}, [%[a0]]! \n\t"
"vld1.s8 {d2, d3}, [%[a1]]! \n\t"
"vld1.s8 {d4, d5}, [%[a2]]! \n\t"
"vld1.s8 {d6, d7}, [%[a3]]! \n\t"
"vst1.s8 {d0, d1}, [%[local_buffer]]! \n\t"
"vst1.s8 {d2, d3}, [%[local_buffer]]! \n\t"
"vst1.s8 {d4, d5}, [%[local_buffer]]! \n\t"
"vst1.s8 {d6, d7}, [%[local_buffer]]! \n\t"
: [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1),
[a2] "+r"(a2), [a3] "+r"(a3)
:
: "memory", "q0", "q1", "q2", "q3");
#endif // __aarch64__
#else
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a0++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a1++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a2++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a3++;
}
#endif // __ARM_NEON
}
if (k_tail != 0) {
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a0++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a1++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a2++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a3++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
}
}
}
// 8 bits int PackMatrixB
void Gemm::PackMatrixB_omp_2c_16(int32_t k, int32_t n, int32_t n_tail,
const int8_t *B, int32_t ldb, int8_t *buffer) {
const int32_t j_length = n - n_tail;
const int32_t k_count = k >> 4;
const int32_t k_tail = k & 15;
#pragma omp parallel for
for (int32_t j = 0; j < j_length; j += 2) {
int8_t *local_buffer = buffer + j * KC;
for (int32_t i = 0; i < k_count; ++i) {
const int8_t *b0 = &B((i << 4), j);
const int8_t *b1 = &B((i << 4), j + 1);
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b1;
b1 += ldb;
}
}
if (k_tail != 0) {
const int8_t *b0 = &B((k_count << 4), j);
const int8_t *b1 = &B((k_count << 4), j + 1);
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b1;
b1 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
}
}
if (n_tail != 0) {
int8_t *local_buffer = buffer + j_length * KC;
for (int32_t i = 0; i < k_count; ++i) {
const int8_t *b0 = &B((i << 4), j_length);
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int m = 0; m < 16; ++m) {
*local_buffer++ = 0;
}
}
if (k_tail != 0) {
const int8_t *b0 = &B((k_count << 4), j_length);
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < KC; ++j) {
*local_buffer++ = 0;
}
}
}
}
}  // namespace math
}  // namespace operators
}  // namespace paddle_mobile
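The packing above lays A out in 4-row-by-16-deep blocks and B in 2-column-by-16-deep blocks, which is the layout the newly declared AddDot4x2 micro-kernel consumes (its NEON implementation lives in the collapsed diff). A scalar sketch of how one 4x2 output tile could be reduced from such packed panels; the indexing and the store-rather-than-accumulate behaviour are assumptions for illustration, not the committed kernel:

#include <cstdint>

// Hypothetical reference: a/b point at one packed panel produced by the
// 4r_16 / 2c_16 pack routines above, KC is the 16-aligned padded depth.
void AddDot4x2_ref(int32_t KC, const int8_t *a, const int8_t *b, int32_t *c,
                   int32_t ldc) {
  for (int32_t r = 0; r < 4; ++r) {          // 4 rows of the A panel
    for (int32_t col = 0; col < 2; ++col) {  // 2 columns of the B panel
      int32_t acc = 0;
      for (int32_t kk = 0; kk < KC; ++kk) {
        const int32_t block = kk >> 4;  // which 16-deep block
        const int32_t lane = kk & 15;   // offset inside the block
        acc += static_cast<int32_t>(a[block * 64 + r * 16 + lane]) *
               static_cast<int32_t>(b[block * 32 + col * 16 + lane]);
      }
      c[r * ldc + col] = acc;  // assumed store; the real kernel may accumulate
    }
  }
}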
...@@ -28,7 +28,12 @@ template <typename T>
void matmul(const framework::Tensor &matrix_a, bool trans_a,
            const framework::Tensor &matrix_b, bool trans_b, T alpha,
            framework::Tensor *matrix_out, T beta, bool relu = false,
-            T *bias = nullptr);
+            float *bias = nullptr);
+void matmul_int8(const framework::Tensor &matrix_a, bool trans_a,
+                 const framework::Tensor &matrix_b, bool trans_b, float alpha,
+                 framework::Tensor *matrix_out, float beta, bool relu = false,
+                 int32_t *bias = nullptr);
template <typename T>
void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a,
......
...@@ -20,11 +20,10 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
namespace math {
-template <>
-void matmul<int8_t>(const framework::Tensor &matrix_a, bool trans_a,
-                    const framework::Tensor &matrix_b, bool trans_b,
-                    int8_t alpha, framework::Tensor *matrix_out, int8_t beta,
-                    bool relu, int8_t *bias) {
+void matmul_int8(const framework::Tensor &matrix_a, bool trans_a,
+                 const framework::Tensor &matrix_b, bool trans_b, float alpha,
+                 framework::Tensor *matrix_out, float beta, bool relu,
+                 int32_t *bias) {
  auto dim_a = matrix_a.dims();
  auto dim_b = matrix_b.dims();
  auto dim_out = matrix_out->dims();
...@@ -52,21 +51,45 @@ void matmul<int8_t>(const framework::Tensor &matrix_a, bool trans_a,
    }
#ifdef _OPENMP
-    gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
-                   matrix_out->data<int32_t>(), N, relu, bias);
+    if (bias != nullptr) {
+      // TODO(wzzju):gemm.Sgemm_omp_with_bias, now use single thread instead.
+      gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
+                 matrix_out->data<int8_t>(), N, relu, bias);
+    } else {
+      gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
+                     matrix_out->data<int32_t>(), N, relu, bias);
+    }
#else
-    gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
-               matrix_out->data<int32_t>(), N, relu, bias);
+    if (bias != nullptr) {
+      gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
+                 matrix_out->data<int8_t>(), N, relu, bias);
+    } else {
+      gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
+                 matrix_out->data<int32_t>(), N, relu, bias);
+    }
#endif
  } else {
#ifdef _OPENMP
-    gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data<int8_t>(), K,
-                   matrix_b.data<int8_t>(), N, beta,
-                   matrix_out->data<int32_t>(), N, relu, bias);
+    if (bias != nullptr) {
+      // TODO(wzzju):gemm.Sgemm_omp_with_bias, now use single thread instead.
+      gemm.Sgemm(M, N, K, alpha, matrix_a.data<int8_t>(), K,
+                 matrix_b.data<int8_t>(), N, beta, matrix_out->data<int8_t>(),
+                 N, relu, bias);
+    } else {
+      gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data<int8_t>(), K,
+                     matrix_b.data<int8_t>(), N, beta,
+                     matrix_out->data<int32_t>(), N, relu, bias);
+    }
#else
-    gemm.Sgemm(M, N, K, alpha, matrix_a.data<int8_t>(), K,
-               matrix_b.data<int8_t>(), N, beta, matrix_out->data<int32_t>(), N,
-               relu, bias);
+    if (bias != nullptr) {
+      gemm.Sgemm(M, N, K, alpha, matrix_a.data<int8_t>(), K,
+                 matrix_b.data<int8_t>(), N, beta, matrix_out->data<int8_t>(),
+                 N, relu, bias);
+    } else {
+      gemm.Sgemm(M, N, K, alpha, matrix_a.data<int8_t>(), K,
+                 matrix_b.data<int8_t>(), N, beta, matrix_out->data<int32_t>(),
+                 N, relu, bias);
+    }
#endif
  }
}
......
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
+#include <climits>
#include <cstdlib>
#include <ctime>
#include <iostream>
...@@ -54,6 +55,30 @@ void print_matirx(int m, int n, int ldc, int8_t *c) {
  std::cout << std::endl;
}
int32_t qadd_int32(int32_t l, int32_t r) {
int64_t res = static_cast<int64_t>(l) + static_cast<int64_t>(r);
if (res > INT_MAX)
return INT_MAX;
else if (res < INT_MIN)
return INT_MIN;
else
return static_cast<int32_t>(res);
}
int8_t qscale_int32(int32_t v, float scale) {
float res = static_cast<float>(v) * scale;
if (res > 0)
res = std::floor(res);
else if (res < 0)
res = std::ceil(res); // round to zero
if (res > 127)
return static_cast<int8_t>(127);
else if (res < -127)
return static_cast<int8_t>(-127);
else
return static_cast<int8_t>(res);
}
int do_sgemm(int m, int n, int k, bool relu, int pr) {
  int lda = k;
  int ldb = n;
...@@ -126,10 +151,98 @@ int do_sgemm(int m, int n, int k, bool relu, int pr) {
  return 0;
}
int do_sgemm_with_bias(int m, int n, int k, bool relu, int pr) {
int lda = k;
int ldb = n;
int ldc = n;
float scale = 0.00628;
default_random_engine e;
uniform_int_distribution<int8_t> pixel(-127, 127);
int8_t *a = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * m * k));
int8_t *b = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * k * n));
int8_t *c = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * m * n));
int8_t *c1 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * m * n));
int32_t *bias =
static_cast<int32_t *>(paddle_mobile::memory::Alloc(sizeof(int32_t) * m));
for (int i = 0; i < m * k; ++i) {
a[i] = pixel(e);
}
for (int i = 0; i < k * n; ++i) {
b[i] = pixel(e);
}
for (int i = 0; i < m; ++i) {
bias[i] = static_cast<int32_t>(pixel(e));
}
for (int i = 0; i < m; ++i) {
int32_t bias_v = bias[i];
for (int j = 0; j < n; ++j) {
int32_t r = 0;
for (int p = 0; p < k; p++) {
r += static_cast<int32_t>(a(i, p)) * static_cast<int32_t>(b(p, j));
}
r = qadd_int32(r, bias_v);
if (relu) r = std::max(0, r);
c1(i, j) = qscale_int32(r, scale);
}
}
paddle_mobile::operators::math::Gemm gemm;
#ifdef _OPENMP
// TODO(wzzju):gemm.Sgemm_omp_with_bias, now use single thread instead.
gemm.Sgemm(m, n, k, scale, a, lda, b, ldb, static_cast<float>(0), c, ldc,
relu, bias);
#else
gemm.Sgemm(m, n, k, scale, a, lda, b, ldb, static_cast<float>(0), c, ldc,
relu, bias);
#endif
int eq = 0;
int neq = 0;
for (int i = 0; i < m * n; ++i) {
if (c[i] == c1[i]) {
++eq;
} else {
++neq;
}
}
if (pr > 0) {
std::cout << "A:" << std::endl;
print_matirx(m, k, lda, a);
std::cout << "B:" << std::endl;
print_matirx(k, n, ldb, b);
std::cout << "Bias:" << std::endl;
print_matirx(m, 1, 1, bias);
std::cout << "C:" << std::endl;
print_matirx(m, n, ldc, c);
std::cout << "C1:" << std::endl;
print_matirx(m, n, ldc, c1);
}
std::cout << "mnk=" << m << " " << n << " " << k << " relu=" << relu
<< " eq=" << eq << " neq=" << neq << std::endl;
paddle_mobile::memory::Free(a);
paddle_mobile::memory::Free(b);
paddle_mobile::memory::Free(c);
paddle_mobile::memory::Free(c1);
paddle_mobile::memory::Free(bias);
return 0;
}
int main() {
#ifdef _OPENMP
-  omp_set_num_threads(8);
+  omp_set_num_threads(4);
#endif
+  std::cout << "\n\n******************************************************\n\n"
+            << std::endl;
+  std::cout << "Test gemm without bias:" << std::endl;
  do_sgemm(9, 9, 9, false, 1);
  do_sgemm(10, 6, 12, false, 0);
  do_sgemm(512, 256, 384, false, 0);
...@@ -140,5 +253,31 @@ int main() {
  do_sgemm(333, 797, 939, false, 0);
  do_sgemm(1024, 1024, 1024, false, 0);
+  std::cout << "\n\n******************************************************\n\n"
+            << std::endl;
+  std::cout << "Test gemm with bias:" << std::endl;
+  do_sgemm_with_bias(9, 9, 9, false, 1);
+  do_sgemm_with_bias(10, 6, 12, false, 0);
+  do_sgemm_with_bias(512, 256, 384, false, 0);
+  do_sgemm_with_bias(1366, 768, 256, false, 0);
+  do_sgemm_with_bias(1255, 755, 333, false, 0);
+  do_sgemm_with_bias(599, 1133, 393, false, 0);
+  do_sgemm_with_bias(777, 555, 999, false, 0);
+  do_sgemm_with_bias(333, 797, 939, false, 0);
+  do_sgemm_with_bias(1024, 1024, 1024, false, 0);
+  std::cout << "\n\n******************************************************\n\n"
+            << std::endl;
+  std::cout << "Test gemm with relu and bias:" << std::endl;
+  do_sgemm_with_bias(9, 9, 9, true, 1);
+  do_sgemm_with_bias(10, 6, 12, true, 0);
+  do_sgemm_with_bias(512, 256, 384, true, 0);
+  do_sgemm_with_bias(1366, 768, 256, true, 0);
+  do_sgemm_with_bias(1255, 755, 333, true, 0);
+  do_sgemm_with_bias(599, 1133, 393, true, 0);
+  do_sgemm_with_bias(777, 555, 999, true, 0);
+  do_sgemm_with_bias(333, 797, 939, true, 0);
+  do_sgemm_with_bias(1024, 1024, 1024, true, 0);
  return 0;
}
...@@ -28,7 +28,7 @@ limitations under the License. */
int main() {
  paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
-  paddle_mobile.SetThreadNum(8);
+  paddle_mobile.SetThreadNum(4);
  Tensor aa, bb, cc;
  auto aaptr = aa.mutable_data<float>({m, k});
  auto bbptr = bb.mutable_data<float>({k, n});
...@@ -44,10 +44,12 @@ int main() {
    ccptr[i] = 2;
  }
-  Tensor aa_int8, bb_int8, cc_int8;
+  Tensor aa_int8, bb_int8, cc_int32, cc_int8;
  auto aaptr_int8 = aa_int8.mutable_data<int8_t>({m, k});
  auto bbptr_int8 = bb_int8.mutable_data<int8_t>({k, n});
-  auto ccptr_int8 = cc_int8.mutable_data<int32_t>({m, n});
+  auto ccptr_int32 = cc_int32.mutable_data<int32_t>({m, n});
+  auto ccptr_int8 = cc_int8.mutable_data<int8_t>({m, n});
+  int32_t* bias_data = new int32_t[m];
  for (int i = 0; i < m * k; ++i) {
    aaptr_int8[i] = static_cast<int8_t>(2);
...@@ -56,7 +58,11 @@ int main() {
    bbptr_int8[i] = static_cast<int8_t>(2);
  }
  for (int i = 0; i < m * n; ++i) {
-    ccptr_int8[i] = static_cast<int32_t>(2);
+    ccptr_int32[i] = static_cast<int32_t>(2);
+  }
+  for (int i = 0; i < m; ++i) {
+    bias_data[i] = 2;
  }
  // float
...@@ -76,22 +82,41 @@ int main() {
  auto time2 = time();
  std::cout << "float gemm cost :" << time_diff(time1, time2) / 10 << "ms\n";
-  // int8_t
+  // int8_t without bias
  // warm-up 10 times
  for (int j = 0; j < 10; ++j) {
-    paddle_mobile::operators::math::matmul<int8_t>(
-        aa_int8, false, bb_int8, false, static_cast<int8_t>(1), &cc_int8,
-        static_cast<int8_t>(0), false, nullptr);
+    paddle_mobile::operators::math::matmul_int8(
+        aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32,
+        static_cast<float>(0), false, nullptr);
  }
  auto time3 = time();
  for (int j = 0; j < 10; ++j) {
-    paddle_mobile::operators::math::matmul<int8_t>(
-        aa_int8, false, bb_int8, false, static_cast<int8_t>(1), &cc_int8,
-        static_cast<int8_t>(0), false, nullptr);
+    paddle_mobile::operators::math::matmul_int8(
+        aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32,
+        static_cast<float>(0), false, nullptr);
  }
  auto time4 = time();
  std::cout << "int8_t gemm cost :" << time_diff(time3, time4) / 10 << "ms\n";
+  // int8_t with bias&relu
+  // warm-up 10 times
+  for (int j = 0; j < 10; ++j) {
+    paddle_mobile::operators::math::matmul_int8(
+        aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int8,
+        static_cast<float>(0), true, &bias_data[0]);
+  }
+  auto time5 = time();
+  for (int j = 0; j < 10; ++j) {
+    paddle_mobile::operators::math::matmul_int8(
+        aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int8,
+        static_cast<float>(0), true, &bias_data[0]);
+  }
+  auto time6 = time();
+  std::cout << "int8_t gemm_with_bias_relu cost :"
+            << time_diff(time5, time6) / 10 << "ms\n";
+  delete[] bias_data;
  return 0;
}