提交 10afce40 编写于 作者: Z Zhen Wang

add 8 bit gemm

上级 0dbc5235
......@@ -142,6 +142,61 @@ void Gemm::PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda,
}
}
// Int8 PackMatrixA: copies A into `buffer` as consecutive 6-row panels.
// Inside a panel the data is stored step by step along k: for every column
// index the six row values are written next to each other.
//
//   m       number of rows of A to pack
//   k       number of columns of A to pack
//   m_tail  number of rows in the last, incomplete panel (0 if none)
//   A, lda  source matrix and its row stride in elements
//   buffer  destination; each panel occupies 6 * k bytes
//
// When m_tail is 1..5 the missing rows of the last panel are read from the
// zero_int8 member, so the panel is padded with zeros.
void Gemm::PackMatrixA_6r(int m, int k, int m_tail, const int8_t *A, int lda,
                          int8_t *buffer) {
  const int full_rows = m - m_tail;
  const int8_t *rows[6];

  // Complete panels: six real rows each.
  for (int row = 0; row < full_rows; row += MR) {
    for (int r = 0; r < 6; ++r) {
      rows[r] = A + (row + r) * lda;
    }
    int8_t *dst = buffer + row * k;
    for (int col = 0; col < k; ++col) {
      for (int r = 0; r < 6; ++r) {
        *dst++ = rows[r][col];
      }
    }
  }

  if (m_tail == 0) {
    return;
  }

  // Last panel: rows at or beyond m_tail are substituted by the zero row,
  // but only for a tail of 1..5 rows (any other value uses A for all rows).
  const int8_t *first = &A(full_rows, 0);
  const bool pad = (m_tail >= 1) && (m_tail <= 5);
  for (int r = 0; r < 6; ++r) {
    if (pad && r >= m_tail) {
      rows[r] = zero_int8;
    } else {
      rows[r] = first + r * lda;
    }
  }
  int8_t *dst = buffer + full_rows * k;
  for (int col = 0; col < k; ++col) {
    for (int r = 0; r < 6; ++r) {
      *dst++ = rows[r][col];
    }
  }
}
void Gemm::PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
float *buffer) {
const int i_length = m - m_tail;
......@@ -384,6 +439,48 @@ void Gemm::PackMatrixA_omp_8r(int m, int k, int m_tail, const float *A, int lda,
}
}
// Int8 PackMatrixB: copies B into `buffer` as consecutive 8-column panels.
// Inside a panel the data is stored row by row: for every row index of B the
// eight column values are written next to each other.
//
//   k       number of rows of B to pack
//   n       number of columns of B to pack
//   n_tail  number of columns in the last, incomplete panel (0 if none)
//   B, ldb  source matrix (accessed through the B(i, j) macro)
//   buffer  destination; each panel occupies 8 * k bytes
//
// The incomplete panel, if any, is padded with zeros on the right.
void Gemm::PackMatrixB_8c(int k, int n, int n_tail, const int8_t *B, int ldb,
                          int8_t *buffer) {
  const int full_cols = n - n_tail;

  // Complete panels: move eight consecutive bytes per row of B.
  for (int col = 0; col < full_cols; col += NR) {
    int8_t *dst = buffer + col * k;
    for (int row = 0; row < k; ++row) {
      const int8_t *src = &B(row, col);
#if __ARM_NEON
      // One 8-byte load followed by one post-incrementing 8-byte store.
      asm volatile(
          "vld1.s8    {d0},   [%[b0]]                 \n\t"
          "vst1.s8    {d0},   [%[local_buffer]]!      \n\t"
          : [local_buffer] "+r"(dst)
          : [b0] "r"(src)
          : "memory", "q0");
#else
      for (int lane = 0; lane < 8; ++lane) {
        *dst++ = *src++;
      }
#endif  // __ARM_NEON
    }
  }

  if (n_tail == 0) {
    return;
  }

  // Incomplete panel: copy the remaining columns, then zero-fill up to NR.
  int8_t *dst = buffer + full_cols * k;
  for (int row = 0; row < k; ++row) {
    const int8_t *src = &B(row, full_cols);
    for (int col = full_cols; col < n; ++col) {
      *dst++ = *src++;
    }
    for (int col = n; col < full_cols + NR; ++col) {
      *dst++ = 0;
    }
  }
}
// 将B矩阵分块复制到连续内存(RowMajor)
void Gemm::PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) {
......@@ -648,6 +745,42 @@ void Gemm::InnerKernel(int mc, int nc, float alpha, const float *a,
}
}
// Int8 blocked matrix multiplication for one (mc x nc) block.
//
// The packed panels `a` and `b` are multiplied tile by tile with AddDot6x8
// into the scratch block `c` (row stride NC); the result is then written to
// the output `C` (row stride ldc) by one of the Write* routines, chosen from
// alpha, beta, relu and bias:
//   alpha != 1                -> WriteWithAlphaBeta
//   beta == 0                 -> WriteBasic
//   beta == 1, relu == false  -> WriteWithAdd / WriteWithAddV1 (with bias)
//   beta == 1, relu == true   -> WriteWithAddRelu / WriteWithAddReluV1
// For any other beta nothing is written to C.
void Gemm::InnerKernelWithBias(int mc, int nc, float alpha, const int8_t *a,
                               const int8_t *b, float beta, int *c, int *C,
                               int ldc, bool relu, int8_t *bias) {
#pragma omp parallel for
  for (int col = 0; col < nc; col += NR) {
    for (int row = 0; row < mc; row += MR) {
      AddDot6x8(KC, a + row * KC, b + col * KC, c + row * NC + col, NC);
    }
  }

  if (alpha != 1) {
    WriteWithAlphaBeta(mc, nc, c, C, ldc);
    return;
  }
  if (beta == 0) {
    WriteBasic(mc, nc, c, C, ldc);
    return;
  }
  if (beta != 1) {
    // No write-back variant exists for this beta.
    return;
  }

  const bool has_bias = (bias != nullptr);
  if (relu) {
    if (has_bias) {
      WriteWithAddReluV1(mc, nc, c, C, ldc, bias);
    } else {
      WriteWithAddRelu(mc, nc, c, C, ldc);
    }
  } else {
    if (has_bias) {
      WriteWithAddV1(mc, nc, c, C, ldc, bias);
    } else {
      WriteWithAdd(mc, nc, c, C, ldc);
    }
  }
}
// 分块矩阵乘法
void Gemm::InnerKernelWithBias(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C,
......@@ -1874,6 +2007,63 @@ void Gemm::AddDot4x8(int k, const float *a, const float *b, float *c, int ldc) {
"q10", "q11", "q12", "q13", "q14", "q15");
}
// C = A * B, 8-bit int GEMM: copies the int32 scratch block `c` to `C`.
//
//   mc, nc  number of rows / columns to copy
//   c       source block, row stride NC elements
//   C       destination, row stride ldc elements
//
// Each row is copied in chunks of 16 ints by the NEON loop; the remaining
// (nc % 16) columns are copied by the scalar loop below it.
void Gemm::WriteBasic(int mc, int nc, int *c, int *C, int ldc) {
  const int nc1 = nc >> 4;     // full 16-int chunks per row
  const int nc_tail = nc & 15;  // leftover columns per row
  // Byte distance from one row of C to the next.
  const int step = sizeof(int) * ldc;
  // Byte distance from the end of the copied chunks of a `c` row to the
  // start of the next `c` row.
  const int step1 = sizeof(int) * (NC - (nc1 << 4));

  if (nc1 > 0) {
    // The assembly advances these three values, so they are passed as
    // read-write operands on private copies; `c` and `C` stay intact for the
    // tail loop. (Modifying an input-only asm operand is not permitted.)
    int rows_left = mc;
    int *c_row = c;
    int *C_row = C;
    asm volatile(
        "subs       %[mc], %[mc], #1                \n\t"
        "blt        end_mc_%=                       \n\t"
        "loop_mc_%=:                                \n\t"

        "mov        r6, %[C_ptr]                    \n\t"
        "mov        r5, %[nc1]                      \n\t"
        "subs       r5, r5, #1                      \n\t"
        "blt        end_nc1_%=                      \n\t"
        "loop_nc1_%=:                               \n\t"

        "vld1.32    {q0, q1}, [%[c_ptr]]!           \n\t"
        "vst1.32    {q0, q1}, [r6]!                 \n\t"

        "vld1.32    {q2, q3}, [%[c_ptr]]!           \n\t"
        "vst1.32    {q2, q3}, [r6]!                 \n\t"

        "subs       r5, r5, #1                      \n\t"
        "bge        loop_nc1_%=                     \n\t"
        "end_nc1_%=:                                \n\t"

        "add        %[C_ptr], %[C_ptr], %[step]     \n\t"
        "add        %[c_ptr], %[c_ptr], %[step1]    \n\t"
        "subs       %[mc], %[mc], #1                \n\t"
        "bge        loop_mc_%=                      \n\t"
        "end_mc_%=:                                 \n\t"

        : [C_ptr] "+r"(C_row), [c_ptr] "+r"(c_row), [mc] "+r"(rows_left)
        : [nc1] "r"(nc1), [step] "r"(step), [step1] "r"(step1)
        // "cc": subs updates the condition flags.
        : "cc", "memory", "r5", "r6", "q0", "q1", "q2", "q3");
  }

  if (nc_tail != 0) {
    for (int i = 0; i < mc; i++) {
      int *C0 = C + nc1 * 16 + i * ldc;
      const int *c0 = c + nc1 * 16 + i * NC;
      for (int j = 0; j < nc_tail; j++) {
        *C0++ = *c0++;
      }
    }
  }
}
// C = A * B
void Gemm::WriteBasic(int mc, int nc, float *c, float *C, int ldc) {
int nc1 = nc / 16;
......@@ -1931,9 +2121,14 @@ void Gemm::WriteBasic(int mc, int nc, float *c, float *C, int ldc) {
}
}
// C = alpha * A * B + beta * C (int32 write-back of the 8-bit int GEMM).
// NOTE(review): unimplemented placeholder - the body is empty and the
// signature carries neither alpha nor beta, so C is left untouched. The int8
// InnerKernelWithBias calls this whenever alpha != 1.
void Gemm::WriteWithAlphaBeta(int mc, int nc, int *c, int *C, int ldc) {}
// C = alpha * A * B + beta * C (float write-back).
// NOTE(review): unimplemented placeholder - the body is empty and the
// signature carries neither alpha nor beta, so C is left untouched.
void Gemm::WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc) {}
// C = A * B + C (int32 write-back of the 8-bit int GEMM).
// NOTE(review): unimplemented placeholder - the body is empty, so C is left
// untouched. The int8 InnerKernelWithBias selects this for alpha == 1,
// beta == 1, relu == false and bias == nullptr.
void Gemm::WriteWithAdd(int mc, int nc, int *c, int *C, int ldc) {}
// C = A * B + C
void Gemm::WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) {
int nc1 = nc / 16;
......@@ -1998,6 +2193,9 @@ void Gemm::WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) {
}
}
// C = A * B + bias (int32 write-back of the 8-bit int GEMM, int8 bias).
// NOTE(review): unimplemented placeholder - the body is empty, so C is left
// untouched. The int8 InnerKernelWithBias selects this for alpha == 1,
// beta == 1, relu == false and a non-null bias.
void Gemm::WriteWithAddV1(int mc, int nc, int *c, int *C, int ldc,
int8_t *bias) {}
// C = A * B + bias
void Gemm::WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc,
float *bias) {
......@@ -2037,6 +2235,9 @@ void Gemm::WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc,
}
}
// C = A * B + C, relu(C) (int32 write-back of the 8-bit int GEMM).
// NOTE(review): unimplemented placeholder - the body is empty, so C is left
// untouched. The int8 InnerKernelWithBias selects this for alpha == 1,
// beta == 1, relu == true and bias == nullptr.
void Gemm::WriteWithAddRelu(int mc, int nc, int *c, int *C, int ldc) {}
// C = A * B + C, relu(C)
void Gemm::WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {
int nc1 = nc / 16;
......@@ -2110,6 +2311,9 @@ void Gemm::WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {
}
}
}
// C = A * B + bias, relu(C) (int32 write-back of the 8-bit int GEMM, int8
// bias).
// NOTE(review): unimplemented placeholder - the body is empty, so C is left
// untouched. The int8 InnerKernelWithBias selects this for alpha == 1,
// beta == 1, relu == true and a non-null bias.
void Gemm::WriteWithAddReluV1(int mc, int nc, int *c, int *C, int ldc,
int8_t *bias) {}
// C = A * B + bias, relu(C)
void Gemm::WriteWithAddReluV1(int mc, int nc, float *c, float *C, int ldc,
......@@ -2996,6 +3200,69 @@ void Gemm::WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc,
#endif // __ARM_NEON
// 8-bit int matrix multiplication: product of an (m x k) int8 matrix A and a
// (k x n) int8 matrix B, written to the int32 matrix C.
//
//   m, n, k     problem dimensions
//   alpha, beta scaling factors, forwarded to InnerKernelWithBias which picks
//               the write-back routine from them
//   A, lda      left operand and its row stride
//   B, ldb      right operand and its row stride
//   C, ldc      output and its row stride
//   relu        forwarded to InnerKernelWithBias
//   bias        optional bias (may be nullptr); offset by the row index of
//               each block before it is forwarded
//
// The work is split into MC x NC blocks sized from the cache constants below;
// A and B blocks are repacked into scratch buffers that are allocated on
// entry and released before returning.
//
// If m, n or k is not positive the function returns immediately and leaves C
// unmodified (the block-size computation below would otherwise divide by
// zero).
void Gemm::Sgemm(int m, int n, int k, float alpha, const int8_t *A, int lda,
                 const int8_t *B, int ldb, float beta, int *C, int ldc,
                 bool relu, int8_t *bias) {
  if (m <= 0 || n <= 0 || k <= 0) {
    return;
  }

  // L1 data cache is 32 KiB (per Cortex-A57, Cortex-A72, Cortex-A73)
  // L2 cache is 0.5~4 MiB (Cortex-A72 cluster)
  const int L1 = 32 * 1024;
  const int L2 = 512 * 1024;

  KC = k;
  MC = L1 / (KC * sizeof(int8_t));
  NC = L2 / (KC * sizeof(int8_t));

  // Make sure MC is a multiple of MR and NC is a multiple of NR, while
  // spreading the rows/columns evenly over the blocks.
  if (MC == 0) {
    MC = MR;
  } else {
    const int mblock_num = (m + MC - 1) / MC;
    MC = (m + mblock_num - 1) / mblock_num;
    MC = (MC + MR - 1) / MR * MR;
  }
  if (NC == 0) {
    NC = NR;
  } else {
    const int nblock_num = (n + NC - 1) / NC;
    NC = (n + nblock_num - 1) / nblock_num;
    NC = (NC + NR - 1) / NR * NR;
  }

  // The NEON AddDot6x8 kernel fetches 8 bytes of packed A per k-step although
  // a step only holds 6, so the final fetch of the last panel touches 2 bytes
  // past MC * KC. Reserve them so the read stays inside the allocation; their
  // contents are never used.
  const int kPackedAOverread = 2;

  packedA_int8 = static_cast<int8_t *>(paddle_mobile::memory::Alloc(
      sizeof(int8_t) * MC * KC + kPackedAOverread));
  packedB_int8 = static_cast<int8_t *>(
      paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC));
  packedC_int8 = static_cast<int *>(
      paddle_mobile::memory::Alloc(sizeof(int) * MC * NC));
  // Row of zeros used by PackMatrixA_6r to pad an incomplete panel.
  zero_int8 =
      static_cast<int8_t *>(paddle_mobile::memory::Alloc(sizeof(int8_t) * KC));
  memset(static_cast<void *>(zero_int8), 0, sizeof(int8_t) * KC);

  for (int j = 0; j < n; j += NC) {
    const int nc = s_min(n - j, NC);
    PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB_int8);
    for (int i = 0; i < m; i += MC) {
      const int mc = s_min(m - i, MC);
      PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA_int8);
      // The bias pointer is shifted to the first row of this block.
      int8_t *block_bias = (bias == nullptr) ? nullptr : bias + i;
      InnerKernelWithBias(mc, nc, alpha, packedA_int8, packedB_int8, beta,
                          packedC_int8, &C(i, j), ldc, relu, block_bias);
    }
  }

  paddle_mobile::memory::Free(packedA_int8);
  paddle_mobile::memory::Free(packedB_int8);
  paddle_mobile::memory::Free(packedC_int8);
  paddle_mobile::memory::Free(zero_int8);
}
// 32位 float 矩阵乘法
void Gemm::Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
......@@ -3589,6 +3856,125 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda,
paddle_mobile::memory::Free(zero);
}
// Int8 micro-kernel: computes one 6 x 8 tile of int32 results.
//
//   k    depth of the product (number of packed steps in a and b)
//   a    packed A panel: 6 int8 values per step (one per tile row)
//   b    packed B panel: 8 int8 values per step (one per tile column)
//   c    destination tile; the tile is overwritten, not accumulated into
//   ldc  row stride of c in elements
//
// For every tile element, c[row * ldc + col] = sum over p of
// a[p * 6 + row] * b[p * 8 + col].
//
// NEON path: two steps are processed per iteration. Their two products are
// summed in 16-bit lanes (vmlal.s8) before being widened into the 32-bit
// accumulators q4-q15, so that pair sum wraps if both products are
// (-128) * (-128). A final odd step is handled separately with vmull.s8.
// Each step loads 8 bytes of `a` although only 6 are consumed, so the kernel
// reads up to 2 bytes past the end of the panel.
void Gemm::AddDot6x8(int k, const int8_t *a, const int8_t *b, int *c, int ldc) {
#if __ARM_NEON
  // Everything the assembly advances or decrements is passed as a read-write
  // operand on a local copy (input-only operands must not be modified).
  const int8_t *a_ptr = a;
  const int8_t *b_ptr = b;
  int *c_ptr = c;
  int kc1 = k >> 1;  // number of step pairs
  int kc2 = k & 1;   // 1 if one odd step remains
  const int step = sizeof(int) * ldc;  // byte stride between rows of c
  asm volatile(
      // q4-q15: save 48 results
      "vmov.s8    q4,     #0          \n\t"
      "vmov.s8    q5,     #0          \n\t"
      "vmov.s8    q6,     #0          \n\t"
      "vmov.s8    q7,     #0          \n\t"
      "vmov.s8    q8,     #0          \n\t"
      "vmov.s8    q9,     #0          \n\t"
      "vmov.s8    q10,    #0          \n\t"
      "vmov.s8    q11,    #0          \n\t"
      "vmov.s8    q12,    #0          \n\t"
      "vmov.s8    q13,    #0          \n\t"
      "vmov.s8    q14,    #0          \n\t"
      "vmov.s8    q15,    #0          \n\t"
      "mov r0, #6                     \n\t"  // bytes of A consumed per step
      "subs %[kc1], %[kc1], #1        \n\t"
      "blt 1f                         \n\t"
      "0:                             \n\t"
      "vld1.s8    {d0},   [%[a_ptr]], r0  \n\t"  // A col0
      "vld1.s8    {d1},   [%[a_ptr]], r0  \n\t"  // A col1, q0 used
      "vld1.s8    {d2-d3}, [%[b_ptr]]!    \n\t"  // B row0, B row1, q1 used
      "vmov.s8    q2,     #0          \n\t"  // q2 used
      "vdup.s8    d6,     d0[0]       \n\t"  // q3 used (but d7 is free)
      "vmlal.s8   q2,     d2,     d6  \n\t"  // A col00 * B row0
      "vdup.s8    d6,     d1[0]       \n\t"
      "vmlal.s8   q2,     d3,     d6  \n\t"  // A col10 * B row1, q3 free
      "vaddw.s16  q4,     q4,     d4  \n\t"
      "vaddw.s16  q5,     q5,     d5  \n\t"  // res row 0
      "vmov.s8    q2,     #0          \n\t"
      "vdup.s8    d6,     d0[1]       \n\t"
      "vmlal.s8   q2,     d2,     d6  \n\t"
      "vdup.s8    d6,     d1[1]       \n\t"
      "vmlal.s8   q2,     d3,     d6  \n\t"
      "vaddw.s16  q6,     q6,     d4  \n\t"
      "vaddw.s16  q7,     q7,     d5  \n\t"  // res row 1
      "vmov.s8    q2,     #0          \n\t"
      "vdup.s8    d6,     d0[2]       \n\t"
      "vmlal.s8   q2,     d2,     d6  \n\t"
      "vdup.s8    d6,     d1[2]       \n\t"
      "vmlal.s8   q2,     d3,     d6  \n\t"
      "vaddw.s16  q8,     q8,     d4  \n\t"
      "vaddw.s16  q9,     q9,     d5  \n\t"  // res row 2
      "vmov.s8    q2,     #0          \n\t"
      "vdup.s8    d6,     d0[3]       \n\t"
      "vmlal.s8   q2,     d2,     d6  \n\t"
      "vdup.s8    d6,     d1[3]       \n\t"
      "vmlal.s8   q2,     d3,     d6  \n\t"
      "vaddw.s16  q10,    q10,    d4  \n\t"
      "vaddw.s16  q11,    q11,    d5  \n\t"  // res row 3
      "vmov.s8    q2,     #0          \n\t"
      "vdup.s8    d6,     d0[4]       \n\t"
      "vmlal.s8   q2,     d2,     d6  \n\t"
      "vdup.s8    d6,     d1[4]       \n\t"
      "vmlal.s8   q2,     d3,     d6  \n\t"
      "vaddw.s16  q12,    q12,    d4  \n\t"
      "vaddw.s16  q13,    q13,    d5  \n\t"  // res row 4
      "vmov.s8    q2,     #0          \n\t"
      "vdup.s8    d6,     d0[5]       \n\t"
      "vmlal.s8   q2,     d2,     d6  \n\t"
      "vdup.s8    d6,     d1[5]       \n\t"
      "vmlal.s8   q2,     d3,     d6  \n\t"
      "vaddw.s16  q14,    q14,    d4  \n\t"
      "vaddw.s16  q15,    q15,    d5  \n\t"  // res row 5

      "subs %[kc1], %[kc1], #1        \n\t"
      "bge 0b                         \n\t"
      "1:                             \n\t"  // odd, last row
      "subs %[kc2], %[kc2], #1        \n\t"
      "blt 2f                         \n\t"
      "vld1.s8    {d0},   [%[a_ptr]]  \n\t"
      "vld1.s8    {d1},   [%[b_ptr]]  \n\t"
      "vdup.s8    d2,     d0[0]       \n\t"
      "vmull.s8   q2,     d1,     d2  \n\t"
      "vaddw.s16  q4,     q4,     d4  \n\t"
      "vaddw.s16  q5,     q5,     d5  \n\t"  // res row 0
      "vdup.s8    d2,     d0[1]       \n\t"
      "vmull.s8   q2,     d1,     d2  \n\t"
      "vaddw.s16  q6,     q6,     d4  \n\t"
      "vaddw.s16  q7,     q7,     d5  \n\t"  // res row 1
      "vdup.s8    d2,     d0[2]       \n\t"
      "vmull.s8   q2,     d1,     d2  \n\t"
      "vaddw.s16  q8,     q8,     d4  \n\t"
      "vaddw.s16  q9,     q9,     d5  \n\t"  // res row 2
      "vdup.s8    d2,     d0[3]       \n\t"
      "vmull.s8   q2,     d1,     d2  \n\t"
      "vaddw.s16  q10,    q10,    d4  \n\t"
      "vaddw.s16  q11,    q11,    d5  \n\t"  // res row 3
      "vdup.s8    d2,     d0[4]       \n\t"
      "vmull.s8   q2,     d1,     d2  \n\t"
      "vaddw.s16  q12,    q12,    d4  \n\t"
      "vaddw.s16  q13,    q13,    d5  \n\t"  // res row 4
      "vdup.s8    d2,     d0[5]       \n\t"
      "vmull.s8   q2,     d1,     d2  \n\t"
      "vaddw.s16  q14,    q14,    d4  \n\t"
      "vaddw.s16  q15,    q15,    d5  \n\t"  // res row 5
      "2:                             \n\t"
      "vst1.32    {q4, q5},   [%[c]], %[step]   \n\t"
      "vst1.32    {q6, q7},   [%[c]], %[step]   \n\t"
      "vst1.32    {q8, q9},   [%[c]], %[step]   \n\t"
      "vst1.32    {q10, q11}, [%[c]], %[step]   \n\t"
      "vst1.32    {q12, q13}, [%[c]], %[step]   \n\t"
      "vst1.32    {q14, q15}, [%[c]]            \n\t"
      : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c] "+r"(c_ptr),
        [kc1] "+r"(kc1), [kc2] "+r"(kc2)
      : [step] "r"(step)
      : "cc", "memory", "r0", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
        "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
#else
  // Portable fallback so that builds without NEON still produce the tile
  // instead of leaving c unwritten. Accumulates directly in 32 bits.
  for (int row = 0; row < 6; ++row) {
    for (int col = 0; col < 8; ++col) {
      int acc = 0;
      for (int p = 0; p < k; ++p) {
        acc += static_cast<int>(a[p * 6 + row]) *
               static_cast<int>(b[p * 8 + col]);
      }
      c[row * ldc + col] = acc;
    }
  }
#endif  // __ARM_NEON
}
void Gemm::AddDot6x8(int k, const float *a, const float *b, float *c, int ldc) {
#if __ARM_NEON
#if __aarch64__
......@@ -3662,7 +4048,7 @@ void Gemm::AddDot6x8(int k, const float *a, const float *b, float *c, int ldc) {
b_ptr = b;
int kc1 = k / 8;
int kc2 = k % 8;
int step = 4 * ldc;
int step = sizeof(float) * ldc;
asm volatile(
"pld [%[a_ptr]] \n\t"
"pld [%[a_ptr], #64] \n\t"
......@@ -3866,11 +4252,10 @@ void Gemm::AddDot6x8(int k, const float *a, const float *b, float *c, int ldc) {
:
: [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1),
[kc2] "r"(kc2), [step] "r"(step)
: "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
: "cc", "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
"q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
#endif // __aarch64__
#else
#endif // __ARM_NEON
}
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stdint-gcc.h>
#include <string>
#include "common/log.h"
......@@ -79,6 +80,12 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
void PackMatrixB_omp_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
// 8-bit int overloads of the packing routines: A is packed into 6-row
// panels, B into 8-column panels.
void PackMatrixA_6r(int m, int k, int m_tail, const int8_t *A, int lda,
int8_t *buffer);
void PackMatrixB_8c(int k, int n, int n_tail, const int8_t *B, int ldb,
int8_t *buffer);
// 分块矩阵乘法
void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
float beta, float *c, float *C, int ldc, bool relu);
......@@ -96,6 +103,12 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
void InnerKernelWithPRelu(int mc, int nc, const float *a, const float *b,
float *c, float *C, int ldc, float *p,
std::string mode, float *bias, float *bias1);
// 8-bit int blocked multiplication: int8 packed inputs a/b, int32 scratch
// block c and int32 output C.
void InnerKernelWithBias(int mc, int nc, float alpha, const int8_t *a,
const int8_t *b, float beta, int *c, int *C, int ldc,
bool relu, int8_t *bias);
/*
// 向量矩阵乘法 (M = 1)
void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda,
......@@ -114,6 +127,8 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
void AddDot8x12(int k, const float *a, const float *b, float *c, int ldc);
void AddDot6x16(int k, const float *a, const float *b, float *c, int ldc);
// 8-bit int micro-kernel: 6 x 8 tile of int32 results from int8 panels a, b.
void AddDot6x8(int k, const int8_t *a, const int8_t *b, int *c, int ldc);
// 分块矩阵乘法结果回写
// C = A * B
void WriteBasic(int mc, int nc, float *c, float *C, int ldc);
......@@ -139,6 +154,20 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
float *new_scale, float *new_bias);
void WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc,
float *new_scale, float *new_bias, float *bias1);
// 8-bit int: write-back of a blocked multiplication result (int32 c -> C)
// C = alpha * A * B + beta * C
void WriteWithAlphaBeta(int mc, int nc, int *c, int *C, int ldc);
// C = A * B
void WriteBasic(int mc, int nc, int *c, int *C, int ldc);
// C = A * B + C
void WriteWithAdd(int mc, int nc, int *c, int *C, int ldc);
// C = A * B + bias
void WriteWithAddV1(int mc, int nc, int *c, int *C, int ldc, int8_t *bias);
// C = A * B + C, relu(C)
void WriteWithAddRelu(int mc, int nc, int *c, int *C, int ldc);
// C = A * B + bias, relu(C)
void WriteWithAddReluV1(int mc, int nc, int *c, int *C, int ldc,
int8_t *bias);
/*
// 向量矩阵乘法结果回写
// C = A * B
......@@ -157,6 +186,11 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
float *new_bias);
*/
// 8-bit int matrix multiplication: int8 inputs A and B, int32 output C.
void Sgemm(int m, int n, int k, float alpha, const int8_t *A, int lda,
const int8_t *B, int ldb, float beta, int *C, int ldc, bool relu,
int8_t *bias);
// 32位 float 矩阵乘法
void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc, bool relu,
......@@ -190,10 +224,17 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
int KC = 0;
int NC = 0;
// 32-bit float scratch buffers
float *packedA;
float *packedB;
float *packedC;
float *zero;
// 8-bit int scratch buffers (packedC_int8 holds int32 results); the int8
// Sgemm allocates them on entry and frees them before returning.
int8_t *packedA_int8;
int8_t *packedB_int8;
int *packedC_int8;
int8_t *zero_int8;
};
} // namespace math
......
......@@ -254,6 +254,10 @@ if (NOT FOUND_MATCH)
ADD_EXECUTABLE(test-gemm-accuracy common/test_gemm_accuracy.cpp)
target_link_libraries(test-gemm-accuracy paddle-mobile)
# gen test: accuracy test for the 8-bit int GEMM, linked against paddle-mobile
ADD_EXECUTABLE(test-gemm-int8-accuracy common/test_gemm_int8_accuracy.cpp)
target_link_libraries(test-gemm-int8-accuracy paddle-mobile)
# gen test
ADD_EXECUTABLE(test-gemm-perf common/test_gemm_perf.cpp)
target_link_libraries(test-gemm-perf paddle-mobile)
......
......@@ -84,7 +84,7 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) {
}
paddle_mobile::operators::math::Gemm gemm;
gemm.SgemmWithBn(m, n, k, 0.9, a, lda, b, ldb, 0.3, c, ldc, relu, scale, bias,
gemm.SgemmWithBn(m, n, k, 1, a, lda, b, ldb, 0.3, c, ldc, relu, scale, bias,
nullptr);
int eq = 0;
int neq = 0;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <random>
#include "../test_helper.h"
#include "common/log.h"
#include "memory/t_malloc.h"
#include "operators/math/gemm.h"
// Row-major element accessors. Each macro expands to an index into the array
// of the same name and uses the leading dimension (lda / ldb / ldc) that is
// in scope at the point of use.
#define a(i, j) a[(i)*lda + (j)]
#define b(i, j) b[(i)*ldb + (j)]
#define c(i, j) c[(i)*ldc + (j)]
#define c1(i, j) c1[(i)*ldc + (j)]
using std::default_random_engine;
using std::uniform_int_distribution;
// Prints the m x n int32 matrix `c` (row stride ldc) to stdout, one row per
// line with " | " between values, followed by an empty line.
void print_matirx(int m, int n, int ldc, int32_t *c) {
  for (int row = 0; row < m; ++row) {
    // The first value of a row is always printed, the rest get a separator.
    std::cout << c(row, 0);
    int col = 1;
    while (col < n) {
      std::cout << " | " << c(row, col);
      ++col;
    }
    std::cout << std::endl;
  }
  std::cout << std::endl;
}
// Prints the m x n int8 matrix `c` (row stride ldc) to stdout, one row per
// line with " | " between values, followed by an empty line. Values are
// widened to int32 so they are shown as numbers rather than characters.
void print_matirx(int m, int n, int ldc, int8_t *c) {
  for (int row = 0; row < m; ++row) {
    // The first value of a row is always printed, the rest get a separator.
    const int32_t first = static_cast<int32_t>(c(row, 0));
    std::cout << first;
    int col = 1;
    while (col < n) {
      const int32_t value = static_cast<int32_t>(c(row, col));
      std::cout << " | " << value;
      ++col;
    }
    std::cout << std::endl;
  }
  std::cout << std::endl;
}
int do_sgemm(int m, int n, int k, bool relu, int pr) {
int lda = k;
int ldb = n;
int ldc = n;
default_random_engine e;
uniform_int_distribution<int8_t> pixel(-127, 127);
int8_t *a =
static_cast<int8_t *>(paddle_mobile::memory::Alloc(sizeof(int8_t) * m * k));
int8_t *b =
static_cast<int8_t *>(paddle_mobile::memory::Alloc(sizeof(int8_t) * k * n));
int32_t *c =
static_cast<int32_t *>(paddle_mobile::memory::Alloc(sizeof(int32_t) * m * n));
int32_t *c1 =
static_cast<int32_t *>(paddle_mobile::memory::Alloc(sizeof(int32_t) * m * n));
for (int i = 0; i < m * k; ++i) {
a[i] = pixel(e);
}
for (int i = 0; i < k * n; ++i) {
b[i] = pixel(e);
}
for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) {
int32_t r = 0;
for (int p = 0; p < k; p++) {
r += static_cast<int32_t>(a(i, p)) * static_cast<int32_t>(b(p, j));
}
c1(i, j) = r;
}
}
paddle_mobile::operators::math::Gemm gemm;
gemm.Sgemm(m, n, k, 1, a, lda, b, ldb, 0, c, ldc, relu,
nullptr);
int eq = 0;
int neq = 0;
for (int i = 0; i < m * n; ++i) {
if (c[i] == c1[i]) {
++eq;
} else {
++neq;
}
}
if (pr > 0) {
std::cout << "A:" << std::endl;
print_matirx(m, k, lda, a);
std::cout << "B:" << std::endl;
print_matirx(k, n, ldb, b);
std::cout << "C:" << std::endl;
print_matirx(m, n, ldc, c);
std::cout << "C1:" << std::endl;
print_matirx(m, n, ldc, c1);
}
std::cout << "mnk=" << m << " " << n << " " << k << " relu=" << relu
<< " eq=" << eq << " neq=" << neq << std::endl;
paddle_mobile::memory::Free(a);
paddle_mobile::memory::Free(b);
paddle_mobile::memory::Free(c);
paddle_mobile::memory::Free(c1);
return 0;
}
// Runs the int8 GEMM accuracy check on a set of problem sizes (the first one
// with matrix printing enabled). Every case is executed; the process exit
// code is 0 only if all of them returned 0, so a failing check is visible to
// whoever launches the test.
int main() {
  int failures = 0;
  failures += (do_sgemm(9, 9, 9, false, 10) != 0) ? 1 : 0;
  failures += (do_sgemm(10, 6, 12, false, 0) != 0) ? 1 : 0;
  failures += (do_sgemm(512, 256, 384, false, 0) != 0) ? 1 : 0;
  failures += (do_sgemm(1366, 768, 256, false, 0) != 0) ? 1 : 0;
  failures += (do_sgemm(1255, 755, 333, false, 0) != 0) ? 1 : 0;
  failures += (do_sgemm(555, 777, 999, false, 0) != 0) ? 1 : 0;
  return failures == 0 ? 0 : 1;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册