提交 dc2c0c01 编写于 作者: H hjchen2

Revert int8 gemm

上级 ad5087c9
......@@ -107,15 +107,9 @@ inline void GemmConv(const ConvParam<CPU> &param) {
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
if (param.Input()->type() == typeid(int8_t)) {
math::matmul_int8(filter_slice, false, col_matrix, false,
math::matmul<Itype>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(0));
} else {
math::matmul<float>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(0));
}
}
}
}
......
......@@ -73,8 +73,8 @@ void MulCompute(const MulParam<CPU> &param) {
}
if (param.InputX()->type() == typeid(int8_t)) {
out->mutable_data<int32_t>();
math::matmul_int8(x_matrix, false, y_matrix, false, static_cast<float>(1),
out, static_cast<float>(0));
math::matmul<int8_t>(x_matrix, false, y_matrix, false,
static_cast<int8_t>(1), out, static_cast<int8_t>(0));
} else {
out->mutable_data<float>();
......
......@@ -23,12 +23,10 @@ limitations under the License. */
#if __aarch64__
#define MR_INT8 4
#define NR_INT8 2
#define MR 6
#define NR 16
#else
#define MR_INT8 4
#define NR_INT8 2
#define MR 6
#define NR 8
#endif
......@@ -195,58 +193,52 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
// 8 bits int small block inner product
void AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc);
void AddDot4x2(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc);
void AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc);
// 8 bits int inner product
void InnerKernel(int32_t mc, int32_t nc, float alpha, const int8_t *a,
const int8_t *b, float beta, int32_t *c, int32_t *C,
int32_t ldc, bool relu);
void InnerKernelWithBias(int32_t mc, int32_t nc, float alpha, const int8_t *a,
const int8_t *b, float beta, int32_t *c, int8_t *C,
int32_t ldc, bool relu, int32_t *bias);
void InnerKernelWithBias(int32_t mc, int32_t nc, int8_t alpha,
const int8_t *a, const int8_t *b, int8_t beta,
int32_t *c, int32_t *C, int32_t ldc, bool relu,
int8_t *bias);
// 8 bits int pack function
void PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer);
void PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer);
void PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer);
void PackMatrixB_2c_16(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
int32_t ldb, int8_t *buffer);
void PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
int32_t ldb, int8_t *buffer);
void PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer);
void PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
int32_t ldb, int8_t *buffer);
void PackMatrixA_omp_4r_16(int32_t m, int32_t k, int32_t m_tail,
const int8_t *A, int32_t lda, int8_t *buffer);
void PackMatrixB_omp_2c_16(int32_t k, int32_t n, int32_t n_tail,
const int8_t *B, int32_t ldb, int8_t *buffer);
// 8 bits int matrix product
void Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
int32_t lda, const int8_t *B, int32_t ldb, float beta, int32_t *C,
int32_t ldc, bool relu, int32_t *bias);
void Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
int32_t lda, const int8_t *B, int32_t ldb, float beta, int8_t *C,
int32_t ldc, bool relu, int32_t *bias);
void Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
int32_t lda, const int8_t *B, int32_t ldb, float beta,
int32_t *C, int32_t ldc, bool relu, int32_t *bias);
void Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A,
int32_t lda, const int8_t *B, int32_t ldb, int8_t beta, int32_t *C,
int32_t ldc, bool relu, int8_t *bias);
void Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A,
int32_t lda, const int8_t *B, int32_t ldb, int8_t beta,
int32_t *C, int32_t ldc, bool relu, int8_t *bias);
// 8 bits int write back
// C = alpha * A * B + beta * C
void WriteWithAlphaBeta(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc);
// C = A * B
void WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, int32_t ldc);
// C = A * B + bias, scale * relu(C)
void WriteWithAddReluScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C,
int32_t ldc, int32_t *bias, float scale);
// C = A * B + bias, scale * C
void WriteWithAddScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C,
int32_t ldc, int32_t *bias, float scale);
// C = A * B + C
void WriteWithAdd(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc);
// C = A * B + bias
void WriteWithAddV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc, int8_t *bias);
// C = A * B + C, relu(C)
void WriteWithAddRelu(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc);
// C = A * B + bias, relu(C)
void WriteWithAddReluV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc, int8_t *bias);
private:
int MC = 0;
......@@ -262,7 +254,7 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
// 8 bits int
int8_t *packedA_int8;
int8_t *packedB_int8;
int32_t *packedC_int32;
int32_t *packedC_int8;
int8_t *zero_int8;
};
......
此差异已折叠。
......@@ -28,10 +28,10 @@ namespace operators {
namespace math {
// 8 bits int matrix product (m*k x k*n)
void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha,
const int8_t *A, int32_t lda, const int8_t *B, int32_t ldb,
float beta, int32_t *C, int32_t ldc, bool relu,
int32_t *bias) {
int8_t beta, int32_t *C, int32_t ldc, bool relu,
int8_t *bias) {
#ifdef _OPENMP
int32_t max_threads = omp_get_max_threads();
#else
......@@ -39,11 +39,10 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
#endif
int32_t L1 = 64 / max_threads * 1024;
const int32_t k_complete = (k + 15) - ((k + 15) & 15);
KC = k_complete;
KC = k;
zero_int8 =
static_cast<int8_t *>(paddle_mobile::memory::Alloc(sizeof(int8_t) * k));
memset(static_cast<void *>(zero_int8), 0, sizeof(int8_t) * k);
static_cast<int8_t *>(paddle_mobile::memory::Alloc(sizeof(int8_t) * KC));
memset(static_cast<void *>(zero_int8), 0, sizeof(int8_t) * KC);
if (m > n) {
// 对 A 分块
MC = L1 / (KC * sizeof(int8_t));
......@@ -55,14 +54,14 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
MC = (MC + MR_INT8 - 1) / MR_INT8 * MR_INT8;
}
// 补齐 B
NC = (n + NR_INT8 - 1) / NR_INT8 * NR_INT8;
NC = (n + NR - 1) / NR * NR;
packedB_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC));
#if __aarch64__
// TODO(wzzju)
#else
PackMatrixB_omp_2c_16(k, n, n % NR_INT8, B, ldb, packedB_int8);
PackMatrixB_omp_8c(KC, n, n % NR, B, ldb, packedB_int8);
#endif
packedA_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC * max_threads));
......@@ -70,11 +69,11 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
// 对 B 分块
NC = L1 / (KC * sizeof(int8_t));
if (NC == 0) {
NC = NR_INT8;
NC = NR;
} else {
int32_t nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + NR_INT8 - 1) / NR_INT8 * NR_INT8;
NC = (NC + NR - 1) / NR * NR;
}
// 补齐 A
MC = (m + MR_INT8 - 1) / MR_INT8 * MR_INT8;
......@@ -84,12 +83,12 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
#if __aarch64__
// TODO(wzzju)
#else
PackMatrixA_omp_4r_16(m, k, m % MR_INT8, A, lda, packedA_int8);
PackMatrixA_omp_4r(m, KC, m % MR_INT8, A, lda, packedA_int8);
#endif
packedB_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC * max_threads));
}
packedC_int32 = static_cast<int32_t *>(
packedC_int8 = static_cast<int32_t *>(
paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC * max_threads));
if (m > n) {
......@@ -104,19 +103,14 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
int32_t mc;
mc = s_min(m - i, MC);
int8_t *local_A = packedA_int8 + MC * KC * local_threads;
int32_t *local_C = packedC_int32 + MC * NC * local_threads;
int32_t *local_C = packedC_int8 + MC * NC * local_threads;
#if __aarch64__
// TODO(wzzju)
#else
PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, local_A);
PackMatrixA_4r(mc, KC, mc % MR_INT8, &A(i, 0), lda, local_A);
#endif
// InnerKernelWithBias(mc, n, alpha, local_A, packedB_int8, beta,
// local_C,
// &C(i, 0), ldc, relu, bias + i);
if (bias == nullptr) {
InnerKernel(mc, n, alpha, local_A, packedB_int8, beta, local_C,
&C(i, 0), ldc, relu);
}
InnerKernelWithBias(mc, n, alpha, local_A, packedB_int8, beta, local_C,
&C(i, 0), ldc, relu, bias + i);
}
} else {
#pragma omp parallel for
......@@ -129,25 +123,20 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
int32_t nc;
nc = s_min(n - j, NC);
int8_t *local_B = packedB_int8 + KC * NC * local_threads;
int32_t *local_C = packedC_int32 + MC * NC * local_threads;
int32_t *local_C = packedC_int8 + MC * NC * local_threads;
#if __aarch64__
// TODO(wzzju)
#else
PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, local_B);
PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, local_B);
#endif
// InnerKernelWithBias(m, nc, alpha, packedA_int8, local_B, beta,
// local_C,
// &C(0, j), ldc, relu, bias);
if (bias == nullptr) {
InnerKernel(m, nc, alpha, packedA_int8, local_B, beta, local_C,
&C(0, j), ldc, relu);
}
InnerKernelWithBias(m, nc, alpha, packedA_int8, local_B, beta, local_C,
&C(0, j), ldc, relu, bias);
}
}
paddle_mobile::memory::Free(packedA_int8);
paddle_mobile::memory::Free(packedB_int8);
paddle_mobile::memory::Free(packedC_int32);
paddle_mobile::memory::Free(packedC_int8);
paddle_mobile::memory::Free(zero_int8);
}
......@@ -155,7 +144,7 @@ void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail,
const int8_t *B, int32_t ldb, int8_t *buffer) {
const int32_t j_length = n - n_tail;
#pragma omp parallel for
for (int32_t j = 0; j < j_length; j += 8) {
for (int32_t j = 0; j < j_length; j += NR) {
int8_t *local_buffer = buffer + j * k;
for (int32_t i = 0; i < k; ++i) {
const int8_t *b0 = &B(i, j);
......@@ -190,7 +179,7 @@ void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail,
for (int32_t j = j_length; j < n; ++j) {
*local_buffer++ = *b0++;
}
for (int32_t j = n; j < j_length + 8; ++j) {
for (int32_t j = n; j < j_length + NR; ++j) {
*local_buffer++ = 0;
}
}
......@@ -199,9 +188,9 @@ void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail,
void Gemm::PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail,
const int8_t *A, int32_t lda, int8_t *buffer) {
const int32_t i_length = m - m_tail;
const int i_length = m - m_tail;
#pragma omp parallel for
for (int32_t i = 0; i < i_length; i += 4) {
for (int32_t i = 0; i < i_length; i += MR_INT8) {
const int8_t *a0 = A + i * lda;
const int8_t *a1 = A + (i + 1) * lda;
const int8_t *a2 = A + (i + 2) * lda;
......@@ -232,7 +221,7 @@ void Gemm::PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail,
default:
break;
}
for (int32_t j = 0; j < k; ++j) {
for (int j = 0; j < k; ++j) {
*local_buffer++ = *a0++;
*local_buffer++ = *a1++;
*local_buffer++ = *a2++;
......@@ -241,232 +230,6 @@ void Gemm::PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail,
}
}
// 8 bits int PackMatrixA_4r
void Gemm::PackMatrixA_omp_4r_16(int32_t m, int32_t k, int32_t m_tail,
const int8_t *A, int32_t lda, int8_t *buffer) {
const int32_t i_length = m - m_tail;
const int32_t k_count = k >> 4;
const int32_t k_tail = k & 15;
#pragma omp parallel for
for (int32_t i = 0; i < i_length; i += 4) {
const int8_t *a0 = A + i * lda;
const int8_t *a1 = A + (i + 1) * lda;
const int8_t *a2 = A + (i + 2) * lda;
const int8_t *a3 = A + (i + 3) * lda;
int8_t *local_buffer = buffer + i * KC;
for (int32_t j = 0; j < k_count; ++j) {
#if __ARM_NEON
#if __aarch64__
// TODO(wzzju)
#else
asm volatile(
"vld1.s8 {d0, d1}, [%[a0]]! \n\t"
"vld1.s8 {d2, d3}, [%[a1]]! \n\t"
"vld1.s8 {d4, d5}, [%[a2]]! \n\t"
"vld1.s8 {d6, d7}, [%[a3]]! \n\t"
"vst1.s8 {d0, d1}, [%[local_buffer]]! \n\t"
"vst1.s8 {d2, d3}, [%[local_buffer]]! \n\t"
"vst1.s8 {d4, d5}, [%[local_buffer]]! \n\t"
"vst1.s8 {d6, d7}, [%[local_buffer]]! \n\t"
: [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1),
[a2] "+r"(a2), [a3] "+r"(a3)
:
: "memory", "q0", "q1", "q2", "q3");
#endif // __aarch64__
#else
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a0++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a1++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a2++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a3++;
}
#endif // __ARM_NEON
}
if (k_tail != 0) {
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a0++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a1++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a2++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a3++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
}
}
if (m_tail != 0) {
const int8_t *a0 = &A(i_length, 0);
const int8_t *a1 = a0 + lda;
const int8_t *a2 = a0 + 2 * lda;
const int8_t *a3 = a0 + 3 * lda;
int8_t *local_buffer = buffer + i_length * KC;
switch (m_tail) {
case 1:
a1 = zero_int8;
case 2:
a2 = zero_int8;
case 3:
a3 = zero_int8;
break;
default:
break;
}
for (int32_t j = 0; j < k_count; ++j) {
#if __ARM_NEON
#if __aarch64__
// TODO(wzzju)
#else
asm volatile(
"vld1.s8 {d0, d1}, [%[a0]]! \n\t"
"vld1.s8 {d2, d3}, [%[a1]]! \n\t"
"vld1.s8 {d4, d5}, [%[a2]]! \n\t"
"vld1.s8 {d6, d7}, [%[a3]]! \n\t"
"vst1.s8 {d0, d1}, [%[local_buffer]]! \n\t"
"vst1.s8 {d2, d3}, [%[local_buffer]]! \n\t"
"vst1.s8 {d4, d5}, [%[local_buffer]]! \n\t"
"vst1.s8 {d6, d7}, [%[local_buffer]]! \n\t"
: [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1),
[a2] "+r"(a2), [a3] "+r"(a3)
:
: "memory", "q0", "q1", "q2", "q3");
#endif // __aarch64__
#else
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a0++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a1++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a2++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a3++;
}
#endif // __ARM_NEON
}
if (k_tail != 0) {
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a0++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a1++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a2++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a3++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
}
}
}
// 8 bits int PackMatrixB
void Gemm::PackMatrixB_omp_2c_16(int32_t k, int32_t n, int32_t n_tail,
const int8_t *B, int32_t ldb, int8_t *buffer) {
const int32_t j_length = n - n_tail;
const int32_t k_count = k >> 4;
const int32_t k_tail = k & 15;
#pragma omp parallel for
for (int32_t j = 0; j < j_length; j += 2) {
int8_t *local_buffer = buffer + j * KC;
for (int32_t i = 0; i < k_count; ++i) {
const int8_t *b0 = &B((i << 4), j);
const int8_t *b1 = &B((i << 4), j + 1);
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b1;
b1 += ldb;
}
}
if (k_tail != 0) {
const int8_t *b0 = &B((k_count << 4), j);
const int8_t *b1 = &B((k_count << 4), j + 1);
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b1;
b1 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
}
}
if (n_tail != 0) {
int8_t *local_buffer = buffer + j_length * KC;
for (int32_t i = 0; i < k_count; ++i) {
const int8_t *b0 = &B((i << 4), j_length);
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int m = 0; m < 16; ++m) {
*local_buffer++ = 0;
}
}
if (k_tail != 0) {
const int8_t *b0 = &B((k_count << 4), j_length);
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < KC; ++j) {
*local_buffer++ = 0;
}
}
}
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......@@ -28,12 +28,7 @@ template <typename T>
void matmul(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, T alpha,
framework::Tensor *matrix_out, T beta, bool relu = false,
float *bias = nullptr);
void matmul_int8(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, float alpha,
framework::Tensor *matrix_out, float beta, bool relu = false,
int32_t *bias = nullptr);
T *bias = nullptr);
template <typename T>
void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a,
......
......@@ -20,10 +20,11 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
namespace math {
void matmul_int8(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, float alpha,
framework::Tensor *matrix_out, float beta, bool relu,
int32_t *bias) {
template <>
void matmul<int8_t>(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b,
int8_t alpha, framework::Tensor *matrix_out, int8_t beta,
bool relu, int8_t *bias) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
......@@ -51,45 +52,21 @@ void matmul_int8(const framework::Tensor &matrix_a, bool trans_a,
}
#ifdef _OPENMP
if (bias != nullptr) {
// TODO(wzzju): gemm.Sgemm_omp_with_bias, now use single thread instead.
gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int8_t>(), N, relu, bias);
} else {
gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int32_t>(), N, relu, bias);
}
gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int32_t>(), N, relu, bias);
#else
if (bias != nullptr) {
gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int8_t>(), N, relu, bias);
} else {
gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int32_t>(), N, relu, bias);
}
gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int32_t>(), N, relu, bias);
#endif
} else {
#ifdef _OPENMP
if (bias != nullptr) {
// TODO(wzzju): gemm.Sgemm_omp_with_bias, now use single thread instead.
gemm.Sgemm(M, N, K, alpha, matrix_a.data<int8_t>(), K,
matrix_b.data<int8_t>(), N, beta, matrix_out->data<int8_t>(),
N, relu, bias);
} else {
gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data<int8_t>(), K,
matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int32_t>(), N, relu, bias);
}
gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data<int8_t>(), K,
matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int32_t>(), N, relu, bias);
#else
if (bias != nullptr) {
gemm.Sgemm(M, N, K, alpha, matrix_a.data<int8_t>(), K,
matrix_b.data<int8_t>(), N, beta, matrix_out->data<int8_t>(),
N, relu, bias);
} else {
gemm.Sgemm(M, N, K, alpha, matrix_a.data<int8_t>(), K,
matrix_b.data<int8_t>(), N, beta, matrix_out->data<int32_t>(),
N, relu, bias);
}
gemm.Sgemm(M, N, K, alpha, matrix_a.data<int8_t>(), K,
matrix_b.data<int8_t>(), N, beta, matrix_out->data<int32_t>(), N,
relu, bias);
#endif
}
}
......
......@@ -28,7 +28,7 @@ limitations under the License. */
int main() {
paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
paddle_mobile.SetThreadNum(4);
paddle_mobile.SetThreadNum(8);
Tensor aa, bb, cc;
auto aaptr = aa.mutable_data<float>({m, k});
auto bbptr = bb.mutable_data<float>({k, n});
......@@ -44,12 +44,10 @@ int main() {
ccptr[i] = 2;
}
Tensor aa_int8, bb_int8, cc_int32, cc_int8;
Tensor aa_int8, bb_int8, cc_int8;
auto aaptr_int8 = aa_int8.mutable_data<int8_t>({m, k});
auto bbptr_int8 = bb_int8.mutable_data<int8_t>({k, n});
auto ccptr_int32 = cc_int32.mutable_data<int32_t>({m, n});
auto ccptr_int8 = cc_int8.mutable_data<int8_t>({m, n});
int32_t* bias_data = new int32_t[m];
auto ccptr_int8 = cc_int8.mutable_data<int32_t>({m, n});
for (int i = 0; i < m * k; ++i) {
aaptr_int8[i] = static_cast<int8_t>(2);
......@@ -58,11 +56,7 @@ int main() {
bbptr_int8[i] = static_cast<int8_t>(2);
}
for (int i = 0; i < m * n; ++i) {
ccptr_int32[i] = static_cast<int32_t>(2);
}
for (int i = 0; i < m; ++i) {
bias_data[i] = 2;
ccptr_int8[i] = static_cast<int32_t>(2);
}
// float
......@@ -82,41 +76,22 @@ int main() {
auto time2 = time();
std::cout << "float gemm cost :" << time_diff(time1, time2) / 10 << "ms\n";
// int8_t without bias
// int8_t
// warm-up 10 times
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul_int8(
aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32,
static_cast<float>(0), false, nullptr);
paddle_mobile::operators::math::matmul<int8_t>(
aa_int8, false, bb_int8, false, static_cast<int8_t>(1), &cc_int8,
static_cast<int8_t>(0), false, nullptr);
}
auto time3 = time();
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul_int8(
aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32,
static_cast<float>(0), false, nullptr);
paddle_mobile::operators::math::matmul<int8_t>(
aa_int8, false, bb_int8, false, static_cast<int8_t>(1), &cc_int8,
static_cast<int8_t>(0), false, nullptr);
}
auto time4 = time();
std::cout << "int8_t gemm cost :" << time_diff(time3, time4) / 10 << "ms\n";
// int8_t with bias&relu
// warm-up 10 times
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul_int8(
aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int8,
static_cast<float>(0), true, &bias_data[0]);
}
auto time5 = time();
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul_int8(
aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int8,
static_cast<float>(0), true, &bias_data[0]);
}
auto time6 = time();
std::cout << "int8_t gemm_with_bias_relu cost :"
<< time_diff(time5, time6) / 10 << "ms\n";
delete[] bias_data;
return 0;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册