diff --git a/src/framework/operator.cpp b/src/framework/operator.cpp
index dd865fb27d4345f16ddca8005463986787d681be..21b14dfcac682e7d310dcf4e8c47afaa0fb68fb3 100644
--- a/src/framework/operator.cpp
+++ b/src/framework/operator.cpp
@@ -32,7 +32,7 @@ template <typename Dtype>
 vector<string> OperatorBase<Dtype>::GetInputKeys() const {
   auto it = op_input_output_key.find(type_);
   if (it == op_input_output_key.end()) {
-    DLOG << type_ << " has no outputs";
+    DLOG << type_ << " has no inputs";
     return {};
   }
   return it->second.first;
diff --git a/src/framework/tensor.h b/src/framework/tensor.h
index 909819c145e2a5388ec42d2609f82929ed337d7d..66ad328fa98aa7d36ba33dc4929567b2ff79884e 100644
--- a/src/framework/tensor.h
+++ b/src/framework/tensor.h
@@ -338,10 +338,12 @@ inline Print &operator<<(Print &printer, const Tensor &tensor) {
   for (int i = 0; i < tensor.numel(); i += stride) {
     if (tensor.type() == typeid(float)) {
       printer << tensor.data<float>()[i] << " ";
+    } else if (tensor.type() == typeid(int32_t)) {
+      printer << tensor.data<int32_t>()[i] << " ";
     } else if (tensor.type() == typeid(int64_t)) {
       printer << tensor.data<int64_t>()[i] << " ";
     } else if (tensor.type() == typeid(int8_t)) {
-      printer << tensor.data<int8_t>()[i] << " ";
+      printer << static_cast<int32_t>(tensor.data<int8_t>()[i]) << " ";
     }
   }
 #endif
diff --git a/src/operators/kernel/arm/mul_kernel.cpp b/src/operators/kernel/arm/mul_kernel.cpp
index aa3ee7077eb7db440c8493eae5b95f03a42196a4..276281f963e449af9d55f7c5ca58ef5da17e6f93 100644
--- a/src/operators/kernel/arm/mul_kernel.cpp
+++ b/src/operators/kernel/arm/mul_kernel.cpp
@@ -31,6 +31,8 @@ void MulKernel<CPU, float>::Compute(const MulParam &param) const {
   param.Out()->set_lod(param.InputX()->lod());
 }
 
+template class MulKernel<CPU, float>;
+
 }  // namespace operators
 }  // namespace paddle_mobile
diff --git a/src/operators/kernel/central-arm-func/mul_arm_func.h b/src/operators/kernel/central-arm-func/mul_arm_func.h
index dd6df54da5a81c2c4d1030103b6bb9811a54246a..07e634e3be9648520357871d91d6677aec6b5c0e 100644
--- a/src/operators/kernel/central-arm-func/mul_arm_func.h
+++ b/src/operators/kernel/central-arm-func/mul_arm_func.h
@@ -58,7 +58,7 @@ void MulCompute(const MulParam &param) {
   const Tensor *input_x = param.InputX();
   const Tensor *input_y = param.InputY();
   Tensor *out = param.Out();
-  out->mutable_data<float>();
+  const Tensor x_matrix = input_x->dims().size() > 2 ?
framework::ReshapeToMatrix(*input_x, param.XNumColDims()) @@ -71,15 +71,21 @@ void MulCompute(const MulParam ¶m) { if (out_dim.size() != 2) { out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); } - math::matmul(x_matrix, false, y_matrix, false, static_cast(1), - out, static_cast(0)); + if (param.InputX()->type() == typeid(int8_t)) { + out->mutable_data(); + math::matmul(x_matrix, false, y_matrix, false, + static_cast(1), out, static_cast(0)); + + } else { + out->mutable_data(); + math::matmul(x_matrix, false, y_matrix, false, static_cast(1), + out, static_cast(0)); + } if (out_dim.size() != 2) { out->Resize(out_dim); } } -template class MulKernel; - } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp index 1fcfc5f98a5279cc4a93da596edbd63c693bd488..2990f7a0f8d4712a3dc3c429d9b57e5aa3809325 100644 --- a/src/operators/math/gemm.cpp +++ b/src/operators/math/gemm.cpp @@ -3662,7 +3662,7 @@ void Gemm::AddDot6x8(int k, const float *a, const float *b, float *c, int ldc) { b_ptr = b; int kc1 = k / 8; int kc2 = k % 8; - int step = 4 * ldc; + int step = sizeof(float) * ldc; asm volatile( "pld [%[a_ptr]] \n\t" "pld [%[a_ptr], #64] \n\t" @@ -3866,11 +3866,10 @@ void Gemm::AddDot6x8(int k, const float *a, const float *b, float *c, int ldc) { : : [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1), [kc2] "r"(kc2), [step] "r"(step) - : "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", - "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); + : "cc", "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", + "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); #endif // __aarch64__ -#else #endif // __ARM_NEON } diff --git a/src/operators/math/gemm.h b/src/operators/math/gemm.h index d7f5b2249ad20f4e2d242ce68b6069ae71a23e28..adc6924d8ad273012a9b44677f8ad1a29bc37787 100644 --- a/src/operators/math/gemm.h +++ b/src/operators/math/gemm.h @@ -96,6 +96,7 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, void InnerKernelWithPRelu(int mc, int nc, const float *a, const float *b, float *c, float *C, int ldc, float *p, std::string mode, float *bias, float *bias1); + /* // 向量矩阵乘法 (M = 1) void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda, @@ -139,6 +140,7 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, float *new_scale, float *new_bias); void WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc, float *new_scale, float *new_bias, float *bias1); + /* // 向量矩阵乘法结果回写 // C = A * B @@ -185,15 +187,63 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, const float *B, int ldb, float *C, int ldc, float *p, std::string mode, float *bias, float *bias1); + // 8 bits function cluster begins + // 8 bits int small block inner product + void AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, + int32_t ldc); + + // 8 bits int inner product + void InnerKernelWithBias(int32_t mc, int32_t nc, int8_t alpha, + const int8_t *a, const int8_t *b, int8_t beta, + int32_t *c, int32_t *C, int32_t ldc, bool relu, + int8_t *bias); + + // 8 bits int pack function + void PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, + int32_t lda, int8_t *buffer); + void PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, + int32_t ldb, int8_t *buffer); + + // 8 bits int matrix product + void Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A, + int32_t lda, const int8_t 
*B, int32_t ldb, int8_t beta, int32_t *C, + int32_t ldc, bool relu, int8_t *bias); + + // 8 bits int write back + // C = alpha * A * B + beta * C + void WriteWithAlphaBeta(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc); + // C = A * B + void WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, int32_t ldc); + // C = A * B + C + void WriteWithAdd(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc); + // C = A * B + bias + void WriteWithAddV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc, int8_t *bias); + // C = A * B + C, relu(C) + void WriteWithAddRelu(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc); + // C = A * B + bias, relu(C) + void WriteWithAddReluV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc, int8_t *bias); + private: int MC = 0; int KC = 0; int NC = 0; + // 32位 float float *packedA; float *packedB; float *packedC; float *zero; + + // 8 bits int + int8_t *packedA_int8; + int8_t *packedB_int8; + int32_t *packedC_int8; + int8_t *zero_int8; }; } // namespace math diff --git a/src/operators/math/gemm_int8.cpp b/src/operators/math/gemm_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bd5286dbcb5c871d5d327875b836ad9777c270bf --- /dev/null +++ b/src/operators/math/gemm_int8.cpp @@ -0,0 +1,717 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "common/log.h" +#include "memory/t_malloc.h" +#include "operators/math/gemm.h" +#if __ARM_NEON +#include +#endif +#ifdef _OPENMP +#include +#endif + +namespace paddle_mobile { +namespace operators { +namespace math { + +// 8 bits int small block inner product +void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, + int32_t ldc) { +#if __ARM_NEON + const int8_t *a_ptr, *b_ptr; + a_ptr = a; + b_ptr = b; + int32_t kc1 = k >> 3; + int32_t kc2 = k & 7; + int32_t kc3 = kc2 >> 2; + int32_t kc4 = kc2 & 3; + int32_t kc5 = kc4 >> 1; + int32_t kc6 = kc4 & 1; + int32_t step = sizeof(int32_t) * ldc; + asm volatile( + // q4-q15: save 48 results + "pld [%[a_ptr]] \n\t" + "pld [%[b_ptr]] \n\t" + "pld [%[b_ptr], #64] \n\t" + "vmov.s8 q4, #0 \n\t" + "vmov.s8 q5, #0 \n\t" + "vmov.s8 q6, #0 \n\t" + "vmov.s8 q7, #0 \n\t" + "vmov.s8 q8, #0 \n\t" + "vmov.s8 q9, #0 \n\t" + "vmov.s8 q10, #0 \n\t" + "vmov.s8 q11, #0 \n\t" + "vmov.s8 q12, #0 \n\t" + "vmov.s8 q13, #0 \n\t" + "vmov.s8 q14, #0 \n\t" + "vmov.s8 q15, #0 \n\t" + "mov r0, #12 \n\t" + "subs %[kc1], %[kc1], #1 \n\t" + "blt 1f \n\t" + "0: \n\t" + "pld [%[a_ptr], #64] \n\t" + "pld [%[b_ptr], #128] \n\t" + "vld1.s8 {d0-d2}, [%[a_ptr]]! \n\t" // A 4 cols, q0 used, + // 1/2 q3 used + "vmov.s8 q2, #0 \n\t" // q2 used + "vld1.s8 {d6-d7}, [%[b_ptr]]! 
\n\t" // B 2 rows, B row1, + // q1 + "vdup.s8 d3, d0[0] \n\t" // q3 used // used + "vmlal.s8 q2, d6, d3 \n\t" // A col00 * B row0 + "vdup.s8 d3, d0[6] \n\t" // q3 used + "vmlal.s8 q2, d7, d3 \n\t" // A col10 * B row1, + // q3 free + "vaddw.s16 q4, q4, d4 \n\t" + "vaddw.s16 q5, q5, d5 \n\t" // res row 0 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d3, d0[1] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d0[7] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q6, q6, d4 \n\t" + "vaddw.s16 q7, q7, d5 \n\t" // res row 1 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d3, d0[2] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d1[0] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q8, q8, d4 \n\t" + "vaddw.s16 q9, q9, d5 \n\t" // res row 2 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d3, d0[3] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d1[1] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q10, q10, d4 \n\t" + "vaddw.s16 q11, q11, d5 \n\t" // res row 3 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d3, d0[4] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d1[2] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q12, q12, d4 \n\t" + "vaddw.s16 q13, q13, d5 \n\t" // res row 4 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d3, d0[5] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d1[3] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q14, q14, d4 \n\t" + "vaddw.s16 q15, q15, d5 \n\t" // res row 5 + + "vld1.s8 {d6-d7}, [%[b_ptr]]! \n\t" // B 2 rows, B row1, + // q1 + "vmov.s8 q2, #0 \n\t" // q2 used + "vdup.s8 d3, d1[4] \n\t" // q3 used // used + "vmlal.s8 q2, d6, d3 \n\t" // A col00 * B row0 + "vdup.s8 d3, d2[2] \n\t" // q3 used + "vmlal.s8 q2, d7, d3 \n\t" // A col10 * B row1, + // q3 free + "vaddw.s16 q4, q4, d4 \n\t" + "vaddw.s16 q5, q5, d5 \n\t" // res row 0 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d3, d1[5] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d2[3] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q6, q6, d4 \n\t" + "vaddw.s16 q7, q7, d5 \n\t" // res row 1 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d3, d1[6] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d2[4] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q8, q8, d4 \n\t" + "vaddw.s16 q9, q9, d5 \n\t" // res row 2 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d3, d1[7] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d2[5] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q10, q10, d4 \n\t" + "vaddw.s16 q11, q11, d5 \n\t" // res row 3 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d3, d2[0] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d2[6] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q12, q12, d4 \n\t" + "vaddw.s16 q13, q13, d5 \n\t" // res row 4 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d3, d2[1] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d2[7] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q14, q14, d4 \n\t" + "vaddw.s16 q15, q15, d5 \n\t" // res row 5 + + "vld1.s8 {d0-d2}, [%[a_ptr]]! \n\t" // A 4 cols, q0 used, + // 1/2 q3 used + "vmov.s8 q2, #0 \n\t" // q2 used + "vld1.s8 {d6-d7}, [%[b_ptr]]! 
\n\t" // B 2 rows, B row1, + // q1 + "vdup.s8 d3, d0[0] \n\t" // q3 used // used + "vmlal.s8 q2, d6, d3 \n\t" // A col00 * B row0 + "vdup.s8 d3, d0[6] \n\t" // q3 used + "vmlal.s8 q2, d7, d3 \n\t" // A col10 * B row1, + // q3 free + "vaddw.s16 q4, q4, d4 \n\t" + "vaddw.s16 q5, q5, d5 \n\t" // res row 0 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d3, d0[1] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d0[7] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q6, q6, d4 \n\t" + "vaddw.s16 q7, q7, d5 \n\t" // res row 1 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d3, d0[2] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d1[0] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q8, q8, d4 \n\t" + "vaddw.s16 q9, q9, d5 \n\t" // res row 2 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d3, d0[3] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d1[1] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q10, q10, d4 \n\t" + "vaddw.s16 q11, q11, d5 \n\t" // res row 3 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d3, d0[4] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d1[2] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q12, q12, d4 \n\t" + "vaddw.s16 q13, q13, d5 \n\t" // res row 4 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d3, d0[5] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d1[3] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q14, q14, d4 \n\t" + "vaddw.s16 q15, q15, d5 \n\t" // res row 5 + + "vld1.s8 {d6-d7}, [%[b_ptr]]! \n\t" // B 2 rows, B row1, + // q1 + "vmov.s8 q2, #0 \n\t" // q2 used + "vdup.s8 d3, d1[4] \n\t" // q3 used // used + "vmlal.s8 q2, d6, d3 \n\t" // A col00 * B row0 + "vdup.s8 d3, d2[2] \n\t" // q3 used + "vmlal.s8 q2, d7, d3 \n\t" // A col10 * B row1, + // q3 free + "vaddw.s16 q4, q4, d4 \n\t" + "vaddw.s16 q5, q5, d5 \n\t" // res row 0 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d3, d1[5] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d2[3] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q6, q6, d4 \n\t" + "vaddw.s16 q7, q7, d5 \n\t" // res row 1 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d3, d1[6] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d2[4] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q8, q8, d4 \n\t" + "vaddw.s16 q9, q9, d5 \n\t" // res row 2 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d3, d1[7] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d2[5] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q10, q10, d4 \n\t" + "vaddw.s16 q11, q11, d5 \n\t" // res row 3 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d3, d2[0] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d2[6] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q12, q12, d4 \n\t" + "vaddw.s16 q13, q13, d5 \n\t" // res row 4 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d3, d2[1] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d2[7] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q14, q14, d4 \n\t" + "vaddw.s16 q15, q15, d5 \n\t" // res row 5 + + "subs %[kc1], %[kc1], #1 \n\t" + "bge 0b \n\t" + "1: \n\t" // last <8 rows + "subs %[kc3], %[kc3], #1 \n\t" + "blt 2f \n\t" + "vld1.s8 {d0-d2}, [%[a_ptr]]! \n\t" + "vmov.s8 q2, #0 \n\t" + "vld1.s8 {d6-d7}, [%[b_ptr]]! 
\n\t" + "vdup.s8 d3, d0[0] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d0[6] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q4, q4, d4 \n\t" + "vaddw.s16 q5, q5, d5 \n\t" // res row 0 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d3, d0[1] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d0[7] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q6, q6, d4 \n\t" + "vaddw.s16 q7, q7, d5 \n\t" // res row 1 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d3, d0[2] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d1[0] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q8, q8, d4 \n\t" + "vaddw.s16 q9, q9, d5 \n\t" // res row 2 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d3, d0[3] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d1[1] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q10, q10, d4 \n\t" + "vaddw.s16 q11, q11, d5 \n\t" // res row 3 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d3, d0[4] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d1[2] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q12, q12, d4 \n\t" + "vaddw.s16 q13, q13, d5 \n\t" // res row 4 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d3, d0[5] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d1[3] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q14, q14, d4 \n\t" + "vaddw.s16 q15, q15, d5 \n\t" // res row 5 + + "vld1.s8 {d6-d7}, [%[b_ptr]]! \n\t" + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d3, d1[4] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d2[2] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q4, q4, d4 \n\t" + "vaddw.s16 q5, q5, d5 \n\t" // res row 0 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d3, d1[5] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d2[3] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q6, q6, d4 \n\t" + "vaddw.s16 q7, q7, d5 \n\t" // res row 1 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d3, d1[6] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d2[4] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q8, q8, d4 \n\t" + "vaddw.s16 q9, q9, d5 \n\t" // res row 2 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d3, d1[7] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d2[5] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q10, q10, d4 \n\t" + "vaddw.s16 q11, q11, d5 \n\t" // res row 3 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d3, d2[0] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d2[6] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q12, q12, d4 \n\t" + "vaddw.s16 q13, q13, d5 \n\t" // res row 4 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d3, d2[1] \n\t" + "vmlal.s8 q2, d6, d3 \n\t" + "vdup.s8 d3, d2[7] \n\t" + "vmlal.s8 q2, d7, d3 \n\t" + "vaddw.s16 q14, q14, d4 \n\t" + "vaddw.s16 q15, q15, d5 \n\t" // res row 5 + + "2: \n\t" // last <4 rows + "subs %[kc5], %[kc5], #1 \n\t" + "blt 3f \n\t" + "vld1.s8 {d0, d1}, [%[a_ptr]], r0 \n\t" + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d6, d0[0] \n\t" + "vld1.s8 {d2-d3}, [%[b_ptr]]! 
\n\t" + "vdup.s8 d7, d0[6] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q4, q4, d4 \n\t" + "vaddw.s16 q5, q5, d5 \n\t" // res row 0 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d6, d0[1] \n\t" + "vdup.s8 d7, d0[7] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q6, q6, d4 \n\t" + "vaddw.s16 q7, q7, d5 \n\t" // res row 1 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d6, d0[2] \n\t" + "vdup.s8 d7, d1[0] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q8, q8, d4 \n\t" + "vaddw.s16 q9, q9, d5 \n\t" // res row 2 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d6, d0[3] \n\t" + "vdup.s8 d7, d1[1] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q10, q10, d4 \n\t" + "vaddw.s16 q11, q11, d5 \n\t" // res row 3 + "vmov.s8 q2, #0. \n\t" + "vdup.s8 d6, d0[4] \n\t" + "vdup.s8 d7, d1[2] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q12, q12, d4 \n\t" + "vaddw.s16 q13, q13, d5 \n\t" // res row 4 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d6, d0[5] \n\t" + "vdup.s8 d7, d1[3] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q14, q14, d4 \n\t" + "vaddw.s16 q15, q15, d5 \n\t" // res row 5 + + "3: \n\t" // last <2 rows + "subs %[kc6], %[kc6], #1 \n\t" + "blt 4f \n\t" + "vld1.s8 {d0}, [%[a_ptr]] \n\t" + "vld1.s8 {d1}, [%[b_ptr]] \n\t" + "vdup.s8 d2, d0[0] \n\t" + "vmull.s8 q2, d1, d2 \n\t" + "vaddw.s16 q4, q4, d4 \n\t" + "vaddw.s16 q5, q5, d5 \n\t" // res row 0 + "vdup.s8 d2, d0[1] \n\t" + "vmull.s8 q2, d1, d2 \n\t" + "vaddw.s16 q6, q6, d4 \n\t" + "vaddw.s16 q7, q7, d5 \n\t" // res row 1 + "vdup.s8 d2, d0[2] \n\t" + "vmull.s8 q2, d1, d2 \n\t" + "vaddw.s16 q8, q8, d4 \n\t" + "vaddw.s16 q9, q9, d5 \n\t" // res row 2 + "vdup.s8 d2, d0[3] \n\t" + "vmull.s8 q2, d1, d2 \n\t" + "vaddw.s16 q10, q10, d4 \n\t" + "vaddw.s16 q11, q11, d5 \n\t" // res row 3 + "vdup.s8 d2, d0[4] \n\t" + "vmull.s8 q2, d1, d2 \n\t" + "vaddw.s16 q12, q12, d4 \n\t" + "vaddw.s16 q13, q13, d5 \n\t" // res row 4 + "vdup.s8 d2, d0[5] \n\t" + "vmull.s8 q2, d1, d2 \n\t" + "vaddw.s16 q14, q14, d4 \n\t" + "vaddw.s16 q15, q15, d5 \n\t" // res row 4 + "4: \n\t" + "vst1.32 {q4, q5}, [%[c]], %[step] \n\t" + "vst1.32 {q6, q7}, [%[c]], %[step] \n\t" + "vst1.32 {q8, q9}, [%[c]], %[step] \n\t" + "vst1.32 {q10, q11}, [%[c]], %[step] \n\t" + "vst1.32 {q12, q13}, [%[c]], %[step] \n\t" + "vst1.32 {q14, q15}, [%[c]] \n\t" + : + : [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1), + [kc3] "r"(kc3), [kc5] "r"(kc5), [kc6] "r"(kc6), [step] "r"(step) + : "cc", "memory", "r0", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +#endif +} + +// 8 bits int inner product +void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, int8_t alpha, + const int8_t *a, const int8_t *b, int8_t beta, + int32_t *c, int32_t *C, int32_t ldc, bool relu, + int8_t *bias) { +#pragma omp parallel for + for (int32_t j = 0; j < nc; j += NR) { + for (int32_t i = 0; i < mc; i += MR) { + AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); + } + } + if (alpha != 1) { + WriteWithAlphaBeta(mc, nc, c, C, ldc); + return; + } + if (beta == 0) { + WriteBasic(mc, nc, c, C, ldc); + return; + } + if (beta == 1 && !relu) { + if (bias == nullptr) { + WriteWithAdd(mc, nc, c, C, ldc); + } else { + WriteWithAddV1(mc, nc, c, C, ldc, bias); + } + return; + } + if (beta == 1 && relu) { + if (bias == nullptr) { + WriteWithAddRelu(mc, nc, c, C, ldc); + } else { + WriteWithAddReluV1(mc, nc, c, C, ldc, bias); + 
} + return; + } +} + +// 8 bits int PackMatrixA +void Gemm::PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, + int32_t lda, int8_t *buffer) { + const int32_t i_length = m - m_tail; + for (int32_t i = 0; i < i_length; i += MR) { + const int8_t *a0 = A + i * lda; + const int8_t *a1 = A + (i + 1) * lda; + const int8_t *a2 = A + (i + 2) * lda; + const int8_t *a3 = A + (i + 3) * lda; + const int8_t *a4 = A + (i + 4) * lda; + const int8_t *a5 = A + (i + 5) * lda; + int8_t *local_buffer = buffer + i * k; + for (int32_t j = 0; j < k; ++j) { + *local_buffer++ = *a0++; + *local_buffer++ = *a1++; + *local_buffer++ = *a2++; + *local_buffer++ = *a3++; + *local_buffer++ = *a4++; + *local_buffer++ = *a5++; + } + } + if (m_tail != 0) { + const int8_t *a0 = &A(i_length, 0); + const int8_t *a1 = a0 + lda; + const int8_t *a2 = a0 + 2 * lda; + const int8_t *a3 = a0 + 3 * lda; + const int8_t *a4 = a0 + 4 * lda; + const int8_t *a5 = a0 + 5 * lda; + int8_t *local_buffer = buffer + i_length * k; + switch (m_tail) { + case 1: + a1 = zero_int8; + case 2: + a2 = zero_int8; + case 3: + a3 = zero_int8; + case 4: + a4 = zero_int8; + case 5: + a5 = zero_int8; + break; + default: + break; + } + for (int32_t j = 0; j < k; ++j) { + *local_buffer++ = *a0++; + *local_buffer++ = *a1++; + *local_buffer++ = *a2++; + *local_buffer++ = *a3++; + *local_buffer++ = *a4++; + *local_buffer++ = *a5++; + } + } +} + +// 8 bits int PackMatrixB +void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, + int32_t ldb, int8_t *buffer) { + const int32_t j_length = n - n_tail; + for (int32_t j = 0; j < j_length; j += NR) { + int8_t *local_buffer = buffer + j * k; + for (int32_t i = 0; i < k; ++i) { + const int8_t *b0 = &B(i, j); +#if __ARM_NEON + asm volatile( + // "pld [%[b0]] \n\t" + "vld1.s8 {d0}, [%[b0]] \n\t" + "vst1.s8 {d0}, [%[local_buffer]]! 
\n\t" + : [local_buffer] "+r"(local_buffer) + : [b0] "r"(b0) + : "memory", "q0"); +#else + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; +#endif // __ARM_NEON + } + } + if (n_tail != 0) { + int8_t *local_buffer = buffer + j_length * k; + for (int32_t i = 0; i < k; ++i) { + const int8_t *b0 = &B(i, j_length); + for (int32_t j = j_length; j < n; ++j) { + *local_buffer++ = *b0++; + } + for (int32_t j = n; j < j_length + NR; ++j) { + *local_buffer++ = 0; + } + } + } +} + +// 8 bits int matrix product (m*k x k*n) +void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A, + int32_t lda, const int8_t *B, int32_t ldb, int8_t beta, + int32_t *C, int32_t ldc, bool relu, int8_t *bias) { + // L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73) + // L2 cache is 0.5~4 Mib (Contex-A72 cluster) + int32_t L1 = 32 * 1024; + int32_t L2 = 512 * 1024; + + KC = k; + MC = L1 / (KC * sizeof(int8_t)); + NC = L2 / (KC * sizeof(int8_t)); + + // make sure MC is multiple of MR, and NC is multiple of NR + if (MC == 0) { + MC = MR; + } else { + int32_t mblock_num = (m + MC - 1) / MC; + MC = (m + mblock_num - 1) / mblock_num; + MC = (MC + MR - 1) / MR * MR; + } + // DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n"; + if (NC == 0) { + NC = NR; + } else { + int32_t nblock_num = (n + NC - 1) / NC; + NC = (n + nblock_num - 1) / nblock_num; + NC = (NC + NR - 1) / NR * NR; + } + // DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n"; + packedA_int8 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC)); + packedB_int8 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC)); + packedC_int8 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC)); + zero_int8 = + static_cast(paddle_mobile::memory::Alloc(sizeof(int8_t) * KC)); + + memset(static_cast(zero_int8), 0, sizeof(int8_t) * KC); + int32_t mc, nc; + for (int32_t j = 0; j < n; j += NC) { + nc = s_min(n - j, NC); + PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB_int8); + for (int32_t i = 0; i < m; i += MC) { + mc = s_min(m - i, MC); + PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA_int8); + if (bias == nullptr) { + InnerKernelWithBias(mc, nc, alpha, packedA_int8, packedB_int8, beta, + packedC_int8, &C(i, j), ldc, relu, nullptr); + } else { + InnerKernelWithBias(mc, nc, alpha, packedA_int8, packedB_int8, beta, + packedC_int8, &C(i, j), ldc, relu, bias + i); + } + } + } + + paddle_mobile::memory::Free(packedA_int8); + paddle_mobile::memory::Free(packedB_int8); + paddle_mobile::memory::Free(packedC_int8); + paddle_mobile::memory::Free(zero_int8); +} + +// 8 bits int write back +// C = alpha * A * B + beta * C +void Gemm::WriteWithAlphaBeta(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc) {} +// C = A * B, 8位 int32_t +void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc) { + int32_t nc1 = nc >> 4; + int32_t _nc1 = nc & 15; + int32_t step = sizeof(int32_t) * ldc; + int32_t step1 = sizeof(int32_t) * (NC - (nc1 << 4)); + int32_t volatile m = mc; + + int32_t *volatile c_ptr, *volatile C_ptr; + int32_t *C0, *c0; + c_ptr = c; + C_ptr = C; + if (nc1 > 0) { + asm volatile( + "subs %[mc], %[mc], #1 \n\t" + "blt end_mc_%= \n\t" + "loop_mc_%=: \n\t" + + "mov r6, %[C_ptr] \n\t" + "mov r5, %[nc1] \n\t" + "subs r5, r5, #1 \n\t" + "blt end_nc1_%= \n\t" + 
"loop_nc1_%=: \n\t" + + "vld1.32 {q0, q1}, [%[c_ptr]]! \n\t" + "vst1.32 {q0, q1}, [r6]! \n\t" + + "vld1.32 {q2, q3}, [%[c_ptr]]! \n\t" + "vst1.32 {q2, q3}, [r6]! \n\t" + + "subs r5, r5, #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" + + "add %[C_ptr], %[C_ptr], %[step] \n\t" + "add %[c_ptr], %[c_ptr], %[step1] \n\t" + "subs %[mc], %[mc], #1 \n\t" + "bge loop_mc_%= \n\t" + "end_mc_%=: \n\t" + + : + : [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(nc1), + [step] "r"(step), [step1] "r"(step1) + : "memory", "r5", "r6", "q0", "q1", "q2", "q3"); + } + + if (_nc1 != 0) { + for (int32_t i = 0; i < mc; i++) { + C0 = C_ptr + nc1 * 16 + i * ldc; + c0 = c_ptr + nc1 * 16 + i * NC; + for (int32_t j = 0; j < _nc1; j++) { + *C0++ = *c0++; + } + } + } +} + +// C = A * B + C +void Gemm::WriteWithAdd(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc) {} + +// C = A * B + bias +void Gemm::WriteWithAddV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc, int8_t *bias) {} +// C = A * B + C, relu(C) +void Gemm::WriteWithAddRelu(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc) {} + +// C = A * B + bias, relu(C) +void Gemm::WriteWithAddReluV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc, int8_t *bias) {} + +} // namespace math +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/math/math_function.cpp b/src/operators/math/math_function.cpp index 9d39f89b04ebcef93fa9d122d629bdf6f4586c66..fc4c385add5ccf30ebe42695fb616e41deb1a827 100644 --- a/src/operators/math/math_function.cpp +++ b/src/operators/math/math_function.cpp @@ -135,7 +135,7 @@ template struct ClearTensor { void operator()(framework::Tensor *tensor) { auto size = tensor->numel(); - auto *tensor_data = tensor->data(); + auto *tensor_data = tensor->data(); memset((void *)tensor_data, 0, sizeof(T) * size); // NOLINT } }; @@ -151,9 +151,9 @@ struct RowwiseAdd { PADDLE_MOBILE_ENFORCE((output->dims() == in_dims), "output->dims() must be equal to in_dims."); - auto *input_data = input.data(); - auto *out_data = output->data(); - auto *vec_data = vector.data(); + auto *input_data = input.data(); + auto *out_data = output->data(); + auto *vec_data = vector.data(); for (int64_t i = 0; i < in_dims[0]; ++i) { for (int64_t j = 0; j < size; ++j) { out_data[i * size + j] = input_data[i * size + j] + vec_data[j]; diff --git a/src/operators/math/math_function.h b/src/operators/math/math_function.h index de19e3df2ab69c8ac490b09af2852bf2fa806c64..b70dfb43ba11400e555365485f2a632c854279ac 100644 --- a/src/operators/math/math_function.h +++ b/src/operators/math/math_function.h @@ -25,7 +25,7 @@ template void matmul(const framework::Tensor &matrix_a, bool trans_a, const framework::Tensor &matrix_b, bool trans_b, T alpha, framework::Tensor *matrix_out, T beta, bool relu = false, - float *bias = nullptr); + T *bias = nullptr); template void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a, diff --git a/src/operators/math/math_function_int8.cpp b/src/operators/math/math_function_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..70677223d12ded2da07ab53bc371f1e8da9fe293 --- /dev/null +++ b/src/operators/math/math_function_int8.cpp @@ -0,0 +1,64 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "operators/math/gemm.h" +#include "operators/math/math_function.h" + +namespace paddle_mobile { +namespace operators { +namespace math { +template <> +void matmul(const framework::Tensor &matrix_a, bool trans_a, + const framework::Tensor &matrix_b, bool trans_b, + int8_t alpha, framework::Tensor *matrix_out, int8_t beta, + bool relu, int8_t *bias) { + auto dim_a = matrix_a.dims(); + auto dim_b = matrix_b.dims(); + auto dim_out = matrix_out->dims(); + PADDLE_MOBILE_ENFORCE( + dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, + "The input and output of matmul be matrix"); + + int32_t M = dim_out[0]; + int32_t N = dim_out[1]; + int32_t K = (!trans_a) ? dim_a[1] : dim_a[0]; + Gemm gemm; + + if (trans_a) { + int32_t numel = matrix_a.numel(); + int32_t m = matrix_a.dims()[0]; + int32_t n = matrix_a.dims()[1]; + int8_t *tmp = (int8_t *)(matrix_a.data()); // NOLINT + int8_t *a = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * numel)); + int32_t index = 0; + for (int32_t j = 0; j < n; j++) { + for (int32_t i = 0; i < m; i++) { + a[index++] = tmp[i * n + j]; + } + } + + gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data(), N, beta, + matrix_out->data(), N, relu, bias); + } else { + gemm.Sgemm(M, N, K, alpha, matrix_a.data(), K, + matrix_b.data(), N, beta, matrix_out->data(), N, + relu, bias); + } +} +} // namespace math +} // namespace operators +} // namespace paddle_mobile diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 1893491f1163c42784af1213b6581ae7817f86b2..a4191954a82928b7e6cd7ea79073cc2f0142f256 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -266,6 +266,10 @@ if (NOT FOUND_MATCH) ADD_EXECUTABLE(test-gemm-accuracy common/test_gemm_accuracy.cpp) target_link_libraries(test-gemm-accuracy paddle-mobile) + # gen test + ADD_EXECUTABLE(test-gemm-int8-accuracy common/test_gemm_int8_accuracy.cpp) + target_link_libraries(test-gemm-int8-accuracy paddle-mobile) + # gen test ADD_EXECUTABLE(test-gemm-perf common/test_gemm_perf.cpp) target_link_libraries(test-gemm-perf paddle-mobile) diff --git a/test/common/test_gemm_accuracy.cpp b/test/common/test_gemm_accuracy.cpp index 0967094f6895d35784a9c06344e3473e66fcd370..2a2505a86b1abab5fe6fd8e0b9905ce7ae78f292 100644 --- a/test/common/test_gemm_accuracy.cpp +++ b/test/common/test_gemm_accuracy.cpp @@ -84,7 +84,7 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) { } paddle_mobile::operators::math::Gemm gemm; - gemm.SgemmWithBn(m, n, k, 0.9, a, lda, b, ldb, 0.3, c, ldc, relu, scale, bias, + gemm.SgemmWithBn(m, n, k, 1, a, lda, b, ldb, 0.3, c, ldc, relu, scale, bias, nullptr); int eq = 0; int neq = 0; diff --git a/test/common/test_gemm_int8_accuracy.cpp b/test/common/test_gemm_int8_accuracy.cpp new file mode 100644 index 0000000000000000000000000000000000000000..80ddd40e121c81032c903955bd7116cf52695569 --- /dev/null +++ b/test/common/test_gemm_int8_accuracy.cpp @@ -0,0 +1,131 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include +#include "../test_helper.h" +#include "common/log.h" +#include "memory/t_malloc.h" +#include "operators/math/gemm.h" + +#define a(i, j) a[(i)*lda + (j)] +#define b(i, j) b[(i)*ldb + (j)] +#define c(i, j) c[(i)*ldc + (j)] +#define c1(i, j) c1[(i)*ldc + (j)] + +using std::default_random_engine; +using std::uniform_int_distribution; + +void print_matirx(int m, int n, int ldc, int32_t *c) { + for (int i = 0; i < m; ++i) { + std::cout << c(i, 0); + for (int j = 1; j < n; ++j) { + std::cout << " | " << c(i, j); + } + std::cout << std::endl; + } + std::cout << std::endl; +} + +void print_matirx(int m, int n, int ldc, int8_t *c) { + for (int i = 0; i < m; ++i) { + std::cout << static_cast(c(i, 0)); + for (int j = 1; j < n; ++j) { + std::cout << " | " << static_cast(c(i, j)); + } + std::cout << std::endl; + } + std::cout << std::endl; +} + +int do_sgemm(int m, int n, int k, bool relu, int pr) { + int lda = k; + int ldb = n; + int ldc = n; + default_random_engine e; + uniform_int_distribution pixel(-127, 127); + int8_t *a = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * m * k)); + int8_t *b = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * k * n)); + int32_t *c = static_cast( + paddle_mobile::memory::Alloc(sizeof(int32_t) * m * n)); + int32_t *c1 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int32_t) * m * n)); + + for (int i = 0; i < m * k; ++i) { + a[i] = pixel(e); + } + for (int i = 0; i < k * n; ++i) { + b[i] = pixel(e); + } + + for (int i = 0; i < m; ++i) { + for (int j = 0; j < n; ++j) { + int32_t r = 0; + for (int p = 0; p < k; p++) { + r += static_cast(a(i, p)) * static_cast(b(p, j)); + } + c1(i, j) = r; + } + } + + paddle_mobile::operators::math::Gemm gemm; + gemm.Sgemm(m, n, k, static_cast(1), a, lda, b, ldb, + static_cast(0), c, ldc, relu, nullptr); + int eq = 0; + int neq = 0; + for (int i = 0; i < m * n; ++i) { + if (c[i] == c1[i]) { + ++eq; + } else { + ++neq; + } + } + + if (pr > 0) { + std::cout << "A:" << std::endl; + print_matirx(m, k, lda, a); + std::cout << "B:" << std::endl; + print_matirx(k, n, ldb, b); + std::cout << "C:" << std::endl; + print_matirx(m, n, ldc, c); + std::cout << "C1:" << std::endl; + print_matirx(m, n, ldc, c1); + } + + std::cout << "mnk=" << m << " " << n << " " << k << " relu=" << relu + << " eq=" << eq << " neq=" << neq << std::endl; + + paddle_mobile::memory::Free(a); + paddle_mobile::memory::Free(b); + paddle_mobile::memory::Free(c); + paddle_mobile::memory::Free(c1); + + return 0; +} + +int main() { + do_sgemm(9, 9, 9, false, 10); + do_sgemm(10, 6, 12, false, 0); + do_sgemm(512, 256, 384, false, 0); + do_sgemm(1366, 768, 256, false, 0); + do_sgemm(1255, 755, 333, false, 0); + do_sgemm(555, 777, 999, false, 0); + do_sgemm(1024, 1024, 1024, false, 0); + + return 0; +} diff --git a/test/common/test_gemm_perf.cpp b/test/common/test_gemm_perf.cpp index 386c09d71a3d5709842991bffd2e8ea039edc940..89f0012ae8effaab383719c1b85748c24eb2bf73 100644 --- 
a/test/common/test_gemm_perf.cpp +++ b/test/common/test_gemm_perf.cpp @@ -28,13 +28,11 @@ limitations under the License. */ int main() { paddle_mobile::PaddleMobile paddle_mobile; - paddle_mobile.SetThreadNum(4); - Tensor aa, bb, cc, scale, bias; + paddle_mobile.SetThreadNum(1); + Tensor aa, bb, cc; auto aaptr = aa.mutable_data({m, k}); auto bbptr = bb.mutable_data({k, n}); auto ccptr = cc.mutable_data({m, n}); - auto scaleptr = scale.mutable_data({m}); - auto biasptr = bias.mutable_data({m}); for (int i = 0; i < m * k; ++i) { aaptr[i] = 2; @@ -45,23 +43,55 @@ int main() { for (int i = 0; i < m * n; ++i) { ccptr[i] = 2; } - for (int i = 0; i < m; ++i) { - scaleptr[i] = 1; - biasptr[i] = 0; + + Tensor aa_int8, bb_int8, cc_int8; + auto aaptr_int8 = aa_int8.mutable_data({m, k}); + auto bbptr_int8 = bb_int8.mutable_data({k, n}); + auto ccptr_int8 = cc_int8.mutable_data({m, n}); + + for (int i = 0; i < m * k; ++i) { + aaptr_int8[i] = static_cast(2); + } + for (int i = 0; i < k * n; ++i) { + bbptr_int8[i] = static_cast(2); + } + for (int i = 0; i < m * n; ++i) { + ccptr_int8[i] = static_cast(2); } - auto time1 = time(); + // float + // warm-up 10 times for (int j = 0; j < 10; ++j) { paddle_mobile::operators::math::matmul( aa, false, bb, false, static_cast(1), &cc, static_cast(0), - false, biasptr); + false, nullptr); + } - // paddle_mobile::operators::math::matmulWithBn( - // aa, false, bb, false, static_cast(1), &cc, - // static_cast(0), true, &scale, &bias, 0); + auto time1 = time(); + for (int j = 0; j < 10; ++j) { + paddle_mobile::operators::math::matmul( + aa, false, bb, false, static_cast(1), &cc, static_cast(0), + false, nullptr); } auto time2 = time(); - std::cout << "gemm cost :" << time_diff(time1, time2) / 10 << "ms\n"; + std::cout << "float gemm cost :" << time_diff(time1, time2) / 10 << "ms\n"; + + // int8_t + // warm-up 10 times + for (int j = 0; j < 10; ++j) { + paddle_mobile::operators::math::matmul( + aa_int8, false, bb_int8, false, static_cast(1), &cc_int8, + static_cast(0), false, nullptr); + } + + auto time3 = time(); + for (int j = 0; j < 10; ++j) { + paddle_mobile::operators::math::matmul( + aa_int8, false, bb_int8, false, static_cast(1), &cc_int8, + static_cast(0), false, nullptr); + } + auto time4 = time(); + std::cout << "int8_t gemm cost :" << time_diff(time3, time4) / 10 << "ms\n"; return 0; } diff --git a/test/operators/test_mul_op.cpp b/test/operators/test_mul_op.cpp index 8ebf0926890497c0ed622b69f163a9f6f5c8612b..678add6dcedd22e788e0bd2df64a8eba59ad8514 100644 --- a/test/operators/test_mul_op.cpp +++ b/test/operators/test_mul_op.cpp @@ -12,80 +12,89 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include +#include "../test_helper.h" #include "../test_include.h" #include "operators/mul_op.h" -int main() { - paddle_mobile::Loader loader; - auto program = loader.Load(g_resnet); - PADDLE_MOBILE_ENFORCE(program.originProgram != nullptr, - "program file read fail"); - - Executor4Test> - executor(program, "mul"); - - // 1. input_tensors; - vector input_tensors; - - Tensor input1; - auto input1_data = CreateInput(&input1, {3, 2, 1, 1}, 0, 1); - input_tensors.push_back(input1); - Tensor input2; - auto input2_data = CreateInput(&input2, {2, 3}, 0, 1); - input_tensors.push_back(input2); - - // 2. input_names - vector input_names({ - "pool2d_0.tmp_0", - "fc_0.w_0", - }); - - // 3. output_names - vector output_names({"fc_0.tmp_0"}); - - // 4. 
out_dims; - vector out_ddims; - auto out_ddim = paddle_mobile::framework::make_ddim({3, 3}); - out_ddims.push_back(out_ddim); - - auto output = executor.Predict(input_tensors, input_names, - output_names, out_ddims); - - auto output0_data = output[0]->data(); - - auto dim_1 = input1.numel() / input1.dims()[0]; - DLOG << " input1 : "; - for (int i = 0; i < input1.dims()[0]; ++i) { - for (int j = 0; j < dim_1; ++j) { - DLOGF("%f ", input1_data[i * dim_1 + j]); - } - DLOGF("\n"); - } - - auto dim_2 = input2.numel() / input2.dims()[0]; - DLOG << " input2 : "; - for (int i = 0; i < input2.dims()[0]; ++i) { - for (int j = 0; j < dim_2; ++j) { - DLOGF("%f ", input2_data[i * dim_2 + j]); +#define a(i, j) a[(i)*lda + (j)] +#define b(i, j) b[(i)*ldb + (j)] +#define c(i, j) c[(i)*ldc + (j)] + +namespace paddle_mobile { +using framework::AttributeMap; +using framework::DDim; +using framework::Scope; +using framework::make_ddim; +template +int TestMulOP() { + int32_t m = 1024; + int32_t n = 1024; + int32_t k = 1024; + int32_t lda = k; + int32_t ldb = n; + int32_t ldc = n; + DDim inputA_shape = make_ddim({m, k}); + DDim inputB_shape = make_ddim({k, n}); + VariableNameMap inputs; + VariableNameMap outputs; + auto scope = std::make_shared(); + inputs["X"] = std::vector({"inputA"}); + inputs["Y"] = std::vector({"inputB"}); + outputs["Out"] = std::vector({"output"}); + + auto inputA_var = scope.get()->Var("inputA"); + auto inputA = inputA_var->template GetMutable(); + SetupTensor(inputA, inputA_shape, -127, 127); + auto inputB_var = scope.get()->Var("inputB"); + auto inputB = inputB_var->template GetMutable(); + SetupTensor(inputB, inputB_shape, -127, 127); + + auto output_var = scope.get()->Var("output"); + AttributeMap attrs; + attrs["x_num_col_dims"].Set(1); + attrs["y_num_col_dims"].Set(1); + auto *op = + new operators::MulOp("mul", inputs, outputs, attrs, scope); + op->InferShape(); + op->Run(); + auto output = output_var->template Get(); + const O *output_data = output->data(); + // compare + O *c = static_cast(memory::Alloc(sizeof(O) * m * n)); + I *a = inputA->data(); + I *b = inputB->data(); + for (int32_t i = 0; i < m; ++i) { + for (int32_t j = 0; j < n; ++j) { + O r = 0; + for (int32_t p = 0; p < k; p++) { + r += static_cast(a(i, p)) * static_cast(b(p, j)); + } + c(i, j) = r; } - DLOGF("\n"); } - auto dim_output0 = output[0]->numel() / output[0]->dims()[0]; - DLOG << " output : "; - for (int i = 0; i < output[0]->dims()[0]; ++i) { - for (int j = 0; j < dim_output0; ++j) { - DLOGF("%f ", output0_data[i * dim_2 + j]); + int32_t eq = 0; + int32_t neq = 0; + for (int32_t i = 0; i < m * n; ++i) { + PADDLE_MOBILE_ENFORCE( + output_data[i] == c[i], "output[%d] = %d, output_cmp[%d] = %d", i, + static_cast(output_data[i]), i, static_cast(c[i])); + if (static_cast(output_data[i] == c[i])) { + ++eq; + } else { + ++neq; } - DLOGF("\n"); } + DLOG << "mnk=" << m << " " << n << " " << k << " eq=" << eq + << " neq=" << neq; + delete op; + return 0; +} +} // namespace paddle_mobile - /// output (3,3) - DLOG << "output memory size : " << output[0]->memory_size(); - DLOG << "output numel : " << output[0]->numel(); - - DLOG << input1_data[0] << " x " << input2_data[0] << " + " << input1_data[1] - << " x " << input2_data[0 + 3] << " = " << output0_data[0]; +int main() { + paddle_mobile::TestMulOP(); + paddle_mobile::TestMulOP(); return 0; }
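
Note (illustrative, not part of the patch): the contract of the new int8 micro-kernel is easiest to see in scalar form. Gemm::AddDot6x8 consumes a panel of A packed by PackMatrixA_6r (6 int8 values per k-step, one per output row) and a panel of B packed by PackMatrixB_8c (8 int8 values per k-step, one per output column), and overwrites a 6x8 block of int32 results whose row stride is ldc elements (the inner kernel passes NC). The sketch below is a hypothetical scalar reference for that contract, handy for checking the NEON kernel against; the function name is invented and the 6/8 tile sizes are inferred from the packing routines in this patch.

#include <cstdint>

// Scalar reference for the int8 Gemm::AddDot6x8 contract, assuming the
// packing layout produced by PackMatrixA_6r / PackMatrixB_8c:
//   a: k steps of 6 int8 values (A panel, one value per output row)
//   b: k steps of 8 int8 values (B panel, one value per output column)
//   c: 6x8 int32 output block, row stride ldc in elements
void AddDot6x8Reference(int32_t k, const int8_t *a, const int8_t *b,
                        int32_t *c, int32_t ldc) {
  int32_t acc[6][8] = {{0}};
  for (int32_t p = 0; p < k; ++p) {
    for (int32_t i = 0; i < 6; ++i) {
      for (int32_t j = 0; j < 8; ++j) {
        acc[i][j] += static_cast<int32_t>(a[p * 6 + i]) *
                     static_cast<int32_t>(b[p * 8 + j]);
      }
    }
  }
  // The NEON kernel zeroes q4-q15 and stores (does not accumulate) into c
  // at the end, so the reference overwrites c as well.
  for (int32_t i = 0; i < 6; ++i) {
    for (int32_t j = 0; j < 8; ++j) {
      c[i * ldc + j] = acc[i][j];
    }
  }
}

Sgemm then amounts to packing MC x KC and KC x NC blocks, running this micro-kernel over the 6x8 tiles into packedC_int8, and writing the block back through WriteBasic (beta == 0) or one of the WriteWithAdd*/WriteWithAlphaBeta variants, several of which are still stubs in this patch.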