From 20654eacb11d492ba0ee2e057f21ebdfd6c31321 Mon Sep 17 00:00:00 2001
From: Zhen Wang
Date: Thu, 18 Oct 2018 22:46:13 +0800
Subject: [PATCH] add int8_t gemm and enable MulOp to support int8_t.

---
 src/framework/operator.cpp                  |   2 +-
 src/framework/tensor.h                      |   4 +-
 src/operators/kernel/arm/mul_kernel.cpp     |   3 +
 .../kernel/central-arm-func/mul_arm_func.h  |  16 +-
 src/operators/math/gemm.h                   |   1 -
 src/operators/math/gemm_int8.cpp            | 435 +++++++++++++-----
 test/operators/test_mul_op.cpp              | 134 +++---
 7 files changed, 413 insertions(+), 182 deletions(-)

diff --git a/src/framework/operator.cpp b/src/framework/operator.cpp
index dd865fb27d..21b14dfcac 100644
--- a/src/framework/operator.cpp
+++ b/src/framework/operator.cpp
@@ -32,7 +32,7 @@ template <typename Dtype>
 vector<string> OperatorBase<Dtype>::GetInputKeys() const {
   auto it = op_input_output_key.find(type_);
   if (it == op_input_output_key.end()) {
-    DLOG << type_ << " has no outputs";
+    DLOG << type_ << " has no inputs";
     return {};
   }
   return it->second.first;
diff --git a/src/framework/tensor.h b/src/framework/tensor.h
index 909819c145..66ad328fa9 100644
--- a/src/framework/tensor.h
+++ b/src/framework/tensor.h
@@ -338,10 +338,12 @@ inline Print &operator<<(Print &printer, const Tensor &tensor) {
   for (int i = 0; i < tensor.numel(); i += stride) {
     if (tensor.type() == typeid(float)) {
       printer << tensor.data<float>()[i] << " ";
+    } else if (tensor.type() == typeid(int32_t)) {
+      printer << tensor.data<int32_t>()[i] << " ";
     } else if (tensor.type() == typeid(int64_t)) {
       printer << tensor.data<int64_t>()[i] << " ";
     } else if (tensor.type() == typeid(int8_t)) {
-      printer << tensor.data<int8_t>()[i] << " ";
+      printer << static_cast<int32_t>(tensor.data<int8_t>()[i]) << " ";
     }
   }
 #endif
diff --git a/src/operators/kernel/arm/mul_kernel.cpp b/src/operators/kernel/arm/mul_kernel.cpp
index aa3ee7077e..4c5867bccd 100644
--- a/src/operators/kernel/arm/mul_kernel.cpp
+++ b/src/operators/kernel/arm/mul_kernel.cpp
@@ -25,12 +25,15 @@ bool MulKernel<CPU, float>::Init(MulParam<CPU> *param) {
   return true;
 }
 
+
 template <>
 void MulKernel<CPU, float>::Compute(const MulParam<CPU> &param) const {
   MulCompute<float>(param);
   param.Out()->set_lod(param.InputX()->lod());
 }
 
+template class MulKernel<CPU, float>;
+
 }  // namespace operators
 }  // namespace paddle_mobile
 
diff --git a/src/operators/kernel/central-arm-func/mul_arm_func.h b/src/operators/kernel/central-arm-func/mul_arm_func.h
index dd6df54da5..07e634e3be 100644
--- a/src/operators/kernel/central-arm-func/mul_arm_func.h
+++ b/src/operators/kernel/central-arm-func/mul_arm_func.h
@@ -58,7 +58,7 @@ void MulCompute(const MulParam<CPU> &param) {
   const Tensor *input_x = param.InputX();
   const Tensor *input_y = param.InputY();
   Tensor *out = param.Out();
-  out->mutable_data<float>();
+
   const Tensor x_matrix =
       input_x->dims().size() > 2
           ? framework::ReshapeToMatrix(*input_x, param.XNumColDims())
@@ -71,15 +71,21 @@ void MulCompute(const MulParam<CPU> &param) {
   if (out_dim.size() != 2) {
     out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]});
   }
-  math::matmul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
-                      out, static_cast<float>(0));
+  if (param.InputX()->type() == typeid(int8_t)) {
+    out->mutable_data<int32_t>();
+    math::matmul<int8_t>(x_matrix, false, y_matrix, false,
+                         static_cast<int8_t>(1), out, static_cast<int8_t>(0));
+
+  } else {
+    out->mutable_data<float>();
+    math::matmul<float>(x_matrix, false, y_matrix, false,
+                        static_cast<float>(1), out, static_cast<float>(0));
+  }
   if (out_dim.size() != 2) {
     out->Resize(out_dim);
   }
 }
 
-template class MulKernel<CPU, float>;
-
 }  // namespace operators
 }  // namespace paddle_mobile
 
diff --git a/src/operators/math/gemm.h b/src/operators/math/gemm.h
index 77c3293bf4..b937173dd3 100644
--- a/src/operators/math/gemm.h
+++ b/src/operators/math/gemm.h
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
-#include
 #include
 #include "common/log.h"
 
diff --git a/src/operators/math/gemm_int8.cpp b/src/operators/math/gemm_int8.cpp
index a885acc0d2..c52cd2fb29 100644
--- a/src/operators/math/gemm_int8.cpp
+++ b/src/operators/math/gemm_int8.cpp
@@ -34,119 +34,340 @@ void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
   const int8_t *a_ptr, *b_ptr;
   a_ptr = a;
   b_ptr = b;
-  int32_t kc1 = k >> 1;
-  int32_t kc2 = k & 1;
+  int32_t kc1 = k >> 3;
+  int32_t kc2 = k & 7;
+  int32_t kc3 = kc2 >> 1;
+  int32_t kc4 = kc2 & 1;
   int32_t step = sizeof(int32_t) * ldc;
   asm volatile(
       // q4-q15: save 48 results
-      "pld [%[a_ptr]]                     \n\t"
-      "pld [%[b_ptr]]                     \n\t"
-      "vmov.s8 q4, #0                     \n\t"
-      "vmov.s8 q5, #0                     \n\t"
-      "vmov.s8 q6, #0                     \n\t"
-      "vmov.s8 q7, #0                     \n\t"
-      "vmov.s8 q8, #0                     \n\t"
-      "vmov.s8 q9, #0                     \n\t"
-      "vmov.s8 q10, #0                    \n\t"
-      "vmov.s8 q11, #0                    \n\t"
-      "vmov.s8 q12, #0                    \n\t"
-      "vmov.s8 q13, #0                    \n\t"
-      "vmov.s8 q14, #0                    \n\t"
-      "vmov.s8 q15, #0                    \n\t"
-      "mov r0, #6                         \n\t"
-      "subs %[kc1], %[kc1], #1            \n\t"
-      "blt 1f                             \n\t"
-      "0:                                 \n\t"
-      "pld [%[a_ptr], #64]                \n\t"
-      "pld [%[b_ptr], #64]                \n\t"
-      "vld1.s8 {d0}, [%[a_ptr]], r0       \n\t"  // A col0
-      "vld1.s8 {d1}, [%[a_ptr]], r0       \n\t"  // A col1, q0 used
-      "vld1.s8 {d2-d3}, [%[b_ptr]]!       \n\t"  // B row0, B row1, q1 used
-      "vmov.s8 q2, #0                     \n\t"  // q2 used
-      "vdup.s8 d6, d0[0]                  \n\t"
-      "vdup.s8 d7, d1[0]                  \n\t"  // q3 used
-      "vmlal.s8 q2, d2, d6                \n\t"  // A col00 * B row0
-      "vmlal.s8 q2, d3, d7                \n\t"  // A col10 * B row1, q3 free
-      "vaddw.s16 q4, q4, d4               \n\t"
-      "vaddw.s16 q5, q5, d5               \n\t"  // res row 0
-      "vmov.s8 q2, #0                     \n\t"
-      "vdup.s8 d6, d0[1]                  \n\t"
-      "vdup.s8 d7, d1[1]                  \n\t"
-      "vmlal.s8 q2, d2, d6                \n\t"
-      "vmlal.s8 q2, d3, d7                \n\t"
-      "vaddw.s16 q6, q6, d4               \n\t"
-      "vaddw.s16 q7, q7, d5               \n\t"  // res row 1
-      "vmov.s8 q2, #0                     \n\t"
-      "vdup.s8 d6, d0[2]                  \n\t"
-      "vdup.s8 d7, d1[2]                  \n\t"
-      "vmlal.s8 q2, d2, d6                \n\t"
-      "vmlal.s8 q2, d3, d7                \n\t"
-      "vaddw.s16 q8, q8, d4               \n\t"
-      "vaddw.s16 q9, q9, d5               \n\t"  // res row 2
-      "vmov.s8 q2, #0                     \n\t"
-      "vdup.s8 d6, d0[3]                  \n\t"
-      "vdup.s8 d7, d1[3]                  \n\t"
-      "vmlal.s8 q2, d2, d6                \n\t"
-      "vmlal.s8 q2, d3, d7                \n\t"
-      "vaddw.s16 q10, q10, d4             \n\t"
-      "vaddw.s16 q11, q11, d5             \n\t"  // res row 3
-      "vmov.s8 q2, #0                     \n\t"
-      "vdup.s8 d6, d0[4]                  \n\t"
-      "vdup.s8 d7, d1[4]                  \n\t"
-      "vmlal.s8 q2, d2, d6                \n\t"
-      "vmlal.s8 q2, d3, d7                \n\t"
-      "vaddw.s16 q12, q12, d4             \n\t"
-      "vaddw.s16 q13, q13, d5             \n\t"  // res row 4
-      "vmov.s8 q2, #0                     \n\t"
-      "vdup.s8 d6, d0[5]                  \n\t"
-      "vdup.s8 d7, d1[5]                  \n\t"
-      "vmlal.s8 q2, d2, d6                \n\t"
-      "vmlal.s8 q2, d3, d7                \n\t"
-      "vaddw.s16 q14, q14, d4             \n\t"
-      "vaddw.s16 q15, q15, d5             \n\t"  // res row 5
-
-      "subs %[kc1], %[kc1], #1            \n\t"
-      "bge 0b                             \n\t"
-      "1:                                 \n\t"  // odd, last row
-      "subs %[kc2], %[kc2], #1            \n\t"
-      "blt 2f                             \n\t"
-      "vld1.s8 {d0}, [%[a_ptr]]           \n\t"
-      "vld1.s8 {d1}, [%[b_ptr]]           \n\t"
-      "vdup.s8 d2, d0[0]                  \n\t"
-      "vmull.s8 q2, d1, d2                \n\t"
-      "vaddw.s16 q4, q4, d4               \n\t"
-      "vaddw.s16 q5, q5, d5               \n\t"  // res row 0
-      "vdup.s8 d2, d0[1]                  \n\t"
-      "vmull.s8 q2, d1, d2                \n\t"
-      "vaddw.s16 q6, q6, d4               \n\t"
-      "vaddw.s16 q7, q7, d5               \n\t"  // res row 1
-      "vdup.s8 d2, d0[2]                  \n\t"
-      "vmull.s8 q2, d1, d2                \n\t"
-      "vaddw.s16 q8, q8, d4               \n\t"
-      "vaddw.s16 q9, q9, d5               \n\t"  // res row 2
-      "vdup.s8 d2, d0[3]                  \n\t"
-      "vmull.s8 q2, d1, d2                \n\t"
-      "vaddw.s16 q10, q10, d4             \n\t"
-      "vaddw.s16 q11, q11, d5             \n\t"  // res row 3
-      "vdup.s8 d2, d0[4]                  \n\t"
-      "vmull.s8 q2, d1, d2                \n\t"
-      "vaddw.s16 q12, q12, d4             \n\t"
-      "vaddw.s16 q13, q13, d5             \n\t"  // res row 4
-      "vdup.s8 d2, d0[5]                  \n\t"
-      "vmull.s8 q2, d1, d2                \n\t"
-      "vaddw.s16 q14, q14, d4             \n\t"
-      "vaddw.s16 q15, q15, d5             \n\t"  // res row 4
-      "2:                                 \n\t"
-      "vst1.32 {q4, q5}, [%[c]], %[step]  \n\t"
-      "vst1.32 {q6, q7}, [%[c]], %[step]  \n\t"
-      "vst1.32 {q8, q9}, [%[c]], %[step]  \n\t"
-      "vst1.32 {q10, q11}, [%[c]], %[step] \n\t"
-      "vst1.32 {q12, q13}, [%[c]], %[step] \n\t"
-      "vst1.32 {q14, q15}, [%[c]]         \n\t"
+      "pld [%[a_ptr]]                     \n\t"
+      "pld [%[b_ptr]]                     \n\t"
+      "vmov.s8 q4, #0                     \n\t"
+      "vmov.s8 q5, #0                     \n\t"
+      "vmov.s8 q6, #0                     \n\t"
+      "vmov.s8 q7, #0                     \n\t"
+      "vmov.s8 q8, #0                     \n\t"
+      "vmov.s8 q9, #0                     \n\t"
+      "vmov.s8 q10, #0                    \n\t"
+      "vmov.s8 q11, #0                    \n\t"
+      "vmov.s8 q12, #0                    \n\t"
+      "vmov.s8 q13, #0                    \n\t"
+      "vmov.s8 q14, #0                    \n\t"
+      "vmov.s8 q15, #0                    \n\t"
+      "mov r0, #6                         \n\t"
+      "subs %[kc1], %[kc1], #1            \n\t"
+      "blt 1f                             \n\t"
+      "0:                                 \n\t"
+      "vld1.s8 {d0}, [%[a_ptr]], r0       \n\t"  // A col0
+      "vld1.s8 {d1}, [%[a_ptr]], r0       \n\t"  // A col1, q0 used
+      "vld1.s8 {d2-d3}, [%[b_ptr]]!       \n\t"  // B row0, B row1, q1 used
+      "vmov.s8 q2, #0                     \n\t"  // q2 used
+      "vdup.s8 d6, d0[0]                  \n\t"
+      "vdup.s8 d7, d1[0]                  \n\t"  // q3 used
+      "vmlal.s8 q2, d2, d6                \n\t"  // A col00 * B row0
+      "vmlal.s8 q2, d3, d7                \n\t"  // A col10 * B row1, q3 free
+      "vaddw.s16 q4, q4, d4               \n\t"
+      "vaddw.s16 q5, q5, d5               \n\t"  // res row 0
+      "vmov.s8 q2, #0                     \n\t"
+      "vdup.s8 d6, d0[1]                  \n\t"
+      "vdup.s8 d7, d1[1]                  \n\t"
+      "vmlal.s8 q2, d2, d6                \n\t"
+      "vmlal.s8 q2, d3, d7                \n\t"
+      "vaddw.s16 q6, q6, d4               \n\t"
+      "vaddw.s16 q7, q7, d5               \n\t"  // res row 1
+      "vmov.s8 q2, #0                     \n\t"
+      "vdup.s8 d6, d0[2]                  \n\t"
+      "vdup.s8 d7, d1[2]                  \n\t"
+      "vmlal.s8 q2, d2, d6                \n\t"
+      "vmlal.s8 q2, d3, d7                \n\t"
+      "vaddw.s16 q8, q8, d4               \n\t"
+      "vaddw.s16 q9, q9, d5               \n\t"  // res row 2
+      "vmov.s8 q2, #0                     \n\t"
+      "vdup.s8 d6, d0[3]                  \n\t"
+      "vdup.s8 d7, d1[3]                  \n\t"
+      "vmlal.s8 q2, d2, d6                \n\t"
+      "vmlal.s8 q2, d3, d7                \n\t"
+      "vaddw.s16 q10, q10, d4             \n\t"
+      "vaddw.s16 q11, q11, d5             \n\t"  // res row 3
+      "vmov.s8 q2, #0                     \n\t"
+      "vdup.s8 d6, d0[4]                  \n\t"
+      "vdup.s8 d7, d1[4]                  \n\t"
+      "vmlal.s8 q2, d2, d6                \n\t"
+      "vmlal.s8 q2, d3, d7                \n\t"
+      "vaddw.s16 q12, q12, d4             \n\t"
+      "vaddw.s16 q13, q13, d5             \n\t"  // res row 4
+      "vmov.s8 q2, #0                     \n\t"
+      "vdup.s8 d6, d0[5]                  \n\t"
+      "vdup.s8 d7, d1[5]                  \n\t"
+      "vmlal.s8 q2, d2, d6                \n\t"
+      "vmlal.s8 q2, d3, d7                \n\t"
+      "vaddw.s16 q14, q14, d4             \n\t"
+      "vaddw.s16 q15, q15, d5             \n\t"  // res row 5
+      "vld1.s8 {d0}, [%[a_ptr]], r0       \n\t"  // A col0
+      "vld1.s8 {d1}, [%[a_ptr]], r0       \n\t"  // A col1, q0 used
+      "vld1.s8 {d2-d3}, [%[b_ptr]]!       \n\t"  // B row0, B row1, q1 used
+      "vmov.s8 q2, #0                     \n\t"  // q2 used
+      "vdup.s8 d6, d0[0]                  \n\t"
+      "vdup.s8 d7, d1[0]                  \n\t"  // q3 used
+      "vmlal.s8 q2, d2, d6                \n\t"  // A col00 * B row0
+      "vmlal.s8 q2, d3, d7                \n\t"  // A col10 * B row1, q3 free
+      "vaddw.s16 q4, q4, d4               \n\t"
+      "vaddw.s16 q5, q5, d5               \n\t"  // res row 0
+      "vmov.s8 q2, #0                     \n\t"
+      "vdup.s8 d6, d0[1]                  \n\t"
+      "vdup.s8 d7, d1[1]                  \n\t"
+      "vmlal.s8 q2, d2, d6                \n\t"
+      "vmlal.s8 q2, d3, d7                \n\t"
+      "vaddw.s16 q6, q6, d4               \n\t"
+      "vaddw.s16 q7, q7, d5               \n\t"  // res row 1
+      "vmov.s8 q2, #0                     \n\t"
+      "vdup.s8 d6, d0[2]                  \n\t"
+      "vdup.s8 d7, d1[2]                  \n\t"
+      "vmlal.s8 q2, d2, d6                \n\t"
+      "vmlal.s8 q2, d3, d7                \n\t"
+      "vaddw.s16 q8, q8, d4               \n\t"
+      "vaddw.s16 q9, q9, d5               \n\t"  // res row 2
+      "vmov.s8 q2, #0                     \n\t"
+      "vdup.s8 d6, d0[3]                  \n\t"
+      "vdup.s8 d7, d1[3]                  \n\t"
+      "vmlal.s8 q2, d2, d6                \n\t"
+      "vmlal.s8 q2, d3, d7                \n\t"
+      "vaddw.s16 q10, q10, d4             \n\t"
+      "vaddw.s16 q11, q11, d5             \n\t"  // res row 3
+      "vmov.s8 q2, #0                     \n\t"
+      "vdup.s8 d6, d0[4]                  \n\t"
+      "vdup.s8 d7, d1[4]                  \n\t"
+      "vmlal.s8 q2, d2, d6                \n\t"
+      "vmlal.s8 q2, d3, d7                \n\t"
+      "vaddw.s16 q12, q12, d4             \n\t"
+      "vaddw.s16 q13, q13, d5             \n\t"  // res row 4
+      "vmov.s8 q2, #0                     \n\t"
+      "vdup.s8 d6, d0[5]                  \n\t"
+      "vdup.s8 d7, d1[5]                  \n\t"
+      "vmlal.s8 q2, d2, d6                \n\t"
+      "vmlal.s8 q2, d3, d7                \n\t"
+      "vaddw.s16 q14, q14, d4             \n\t"
+      "vaddw.s16 q15, q15, d5             \n\t"  // res row 5
+
+      "vld1.s8 {d0}, [%[a_ptr]], r0       \n\t"  // A col0
+      "vld1.s8 {d1}, [%[a_ptr]], r0       \n\t"  // A col1, q0 used
+      "vld1.s8 {d2-d3}, [%[b_ptr]]!       \n\t"  // B row0, B row1, q1 used
+      "vmov.s8 q2, #0                     \n\t"  // q2 used
+      "vdup.s8 d6, d0[0]                  \n\t"
+      "vdup.s8 d7, d1[0]                  \n\t"  // q3 used
+      "vmlal.s8 q2, d2, d6                \n\t"  // A col00 * B row0
+      "vmlal.s8 q2, d3, d7                \n\t"  // A col10 * B row1, q3 free
+      "vaddw.s16 q4, q4, d4               \n\t"
+      "vaddw.s16 q5, q5, d5               \n\t"  // res row 0
+      "vmov.s8 q2, #0                     \n\t"
+      "vdup.s8 d6, d0[1]                  \n\t"
+      "vdup.s8 d7, d1[1]                  \n\t"
+      "vmlal.s8 q2, d2, d6                \n\t"
+      "vmlal.s8 q2, d3, d7                \n\t"
+      "vaddw.s16 q6, q6, d4               \n\t"
+      "vaddw.s16 q7, q7, d5               \n\t"  // res row 1
+      "vmov.s8 q2, #0                     \n\t"
+      "vdup.s8 d6, d0[2]                  \n\t"
+      "vdup.s8 d7, d1[2]                  \n\t"
+      "vmlal.s8 q2, d2, d6                \n\t"
+      "vmlal.s8 q2, d3, d7                \n\t"
+      "vaddw.s16 q8, q8, d4               \n\t"
+      "vaddw.s16 q9, q9, d5               \n\t"  // res row 2
+      "vmov.s8 q2, #0                     \n\t"
+      "vdup.s8 d6, d0[3]                  \n\t"
+      "vdup.s8 d7, d1[3]                  \n\t"
+      "vmlal.s8 q2, d2, d6                \n\t"
+      "vmlal.s8 q2, d3, d7                \n\t"
+      "vaddw.s16 q10, q10, d4             \n\t"
+      "vaddw.s16 q11, q11, d5             \n\t"  // res row 3
+      "vmov.s8 q2, #0                     \n\t"
+      "vdup.s8 d6, d0[4]                  \n\t"
+      "vdup.s8 d7, d1[4]                  \n\t"
+      "vmlal.s8 q2, d2, d6                \n\t"
+      "vmlal.s8 q2, d3, d7                \n\t"
+      "vaddw.s16 q12, q12, d4             \n\t"
+      "vaddw.s16 q13, q13, d5             \n\t"  // res row 4
+      "vmov.s8 q2, #0                     \n\t"
+      "vdup.s8 d6, d0[5]                  \n\t"
+      "vdup.s8 d7, d1[5]                  \n\t"
+      "vmlal.s8 q2, d2, d6                \n\t"
+      "vmlal.s8 q2, d3, d7                \n\t"
+      "vaddw.s16 q14, q14, d4             \n\t"
+      "vaddw.s16 q15, q15, d5             \n\t"  // res row 5
+
+      "vld1.s8 {d0}, [%[a_ptr]], r0       \n\t"  // A col0
+      "vld1.s8 {d1}, [%[a_ptr]], r0       \n\t"  // A col1, q0 used
+      "vld1.s8 {d2-d3}, [%[b_ptr]]!       \n\t"  // B row0, B row1, q1 used
+      "vmov.s8 q2, #0                     \n\t"  // q2 used
+      "vdup.s8 d6, d0[0]                  \n\t"
+      "vdup.s8 d7, d1[0]                  \n\t"  // q3 used
+      "vmlal.s8 q2, d2, d6                \n\t"  // A col00 * B row0
+      "vmlal.s8 q2, d3, d7                \n\t"  // A col10 * B row1, q3 free
+      "vaddw.s16 q4, q4, d4               \n\t"
+      "vaddw.s16 q5, q5, d5               \n\t"  // res row 0
+      "vmov.s8 q2, #0                     \n\t"
+      "vdup.s8 d6, d0[1]                  \n\t"
+      "vdup.s8 d7, d1[1]                  \n\t"
+      "vmlal.s8 q2, d2, d6                \n\t"
+      "vmlal.s8 q2, d3, d7                \n\t"
+      "vaddw.s16 q6, q6, d4               \n\t"
+      "vaddw.s16 q7, q7, d5               \n\t"  // res row 1
+      "vmov.s8 q2, #0                     \n\t"
+      "vdup.s8 d6, d0[2]                  \n\t"
+      "vdup.s8 d7, d1[2]                  \n\t"
+      "vmlal.s8 q2, d2, d6                \n\t"
+      "vmlal.s8 q2, d3, d7                \n\t"
+      "vaddw.s16 q8, q8, d4               \n\t"
+      "vaddw.s16 q9, q9, d5               \n\t"  // res row 2
+      "vmov.s8 q2, #0                     \n\t"
+      "vdup.s8 d6, d0[3]                  \n\t"
+      "vdup.s8 d7, d1[3]                  \n\t"
+      "vmlal.s8 q2, d2, d6                \n\t"
+      "vmlal.s8 q2, d3, d7                \n\t"
+      "vaddw.s16 q10, q10, d4             \n\t"
+      "vaddw.s16 q11, q11, d5             \n\t"  // res row 3
+      "vmov.s8 q2, #0                     \n\t"
+      "vdup.s8 d6, d0[4]                  \n\t"
+      "vdup.s8 d7, d1[4]                  \n\t"
+      "vmlal.s8 q2, d2, d6                \n\t"
+      "vmlal.s8 q2, d3, d7                \n\t"
+      "vaddw.s16 q12, q12, d4             \n\t"
+      "vaddw.s16 q13, q13, d5             \n\t"  // res row 4
+      "vmov.s8 q2, #0                     \n\t"
+      "vdup.s8 d6, d0[5]                  \n\t"
+      "vdup.s8 d7, d1[5]                  \n\t"
+      "vmlal.s8 q2, d2, d6                \n\t"
+      "vmlal.s8 q2, d3, d7                \n\t"
+      "vaddw.s16 q14, q14, d4             \n\t"
+      "vaddw.s16 q15, q15, d5             \n\t"  // res row 5
+
+      "subs %[kc1], %[kc1], #1            \n\t"  // last <8 rows
+      "bge 0b                             \n\t"
+      "1:                                 \n\t"
+      "subs %[kc3], %[kc3], #1            \n\t"
+      "blt 3f                             \n\t"
+      "2:                                 \n\t"
+      "vld1.s8 {d0}, [%[a_ptr]], r0       \n\t"  // A col0
+      "vld1.s8 {d1}, [%[a_ptr]], r0       \n\t"  // A col1, q0 used
+      "vld1.s8 {d2-d3}, [%[b_ptr]]!       \n\t"  // B row0, B row1, q1 used
+      "vmov.s8 q2, #0                     \n\t"  // q2 used
+      "vdup.s8 d6, d0[0]                  \n\t"
+      "vdup.s8 d7, d1[0]                  \n\t"  // q3 used
+      "vmlal.s8 q2, d2, d6                \n\t"  // A col00 * B row0
+      "vmlal.s8 q2, d3, d7                \n\t"  // A col10 * B row1, q3 free
+      "vaddw.s16 q4, q4, d4               \n\t"
+      "vaddw.s16 q5, q5, d5               \n\t"  // res row 0
+      "vmov.s8 q2, #0                     \n\t"
+      "vdup.s8 d6, d0[1]                  \n\t"
+      "vdup.s8 d7, d1[1]                  \n\t"
+      "vmlal.s8 q2, d2, d6                \n\t"
+      "vmlal.s8 q2, d3, d7                \n\t"
+      "vaddw.s16 q6, q6, d4               \n\t"
+      "vaddw.s16 q7, q7, d5               \n\t"  // res row 1
+      "vmov.s8 q2, #0                     \n\t"
+      "vdup.s8 d6, d0[2]                  \n\t"
+      "vdup.s8 d7, d1[2]                  \n\t"
+      "vmlal.s8 q2, d2, d6                \n\t"
+      "vmlal.s8 q2, d3, d7                \n\t"
+      "vaddw.s16 q8, q8, d4               \n\t"
+      "vaddw.s16 q9, q9, d5               \n\t"  // res row 2
+      "vmov.s8 q2, #0                     \n\t"
+      "vdup.s8 d6, d0[3]                  \n\t"
+      "vdup.s8 d7, d1[3]                  \n\t"
+      "vmlal.s8 q2, d2, d6                \n\t"
+      "vmlal.s8 q2, d3, d7                \n\t"
+      "vaddw.s16 q10, q10, d4             \n\t"
+      "vaddw.s16 q11, q11, d5             \n\t"  // res row 3
+      "vmov.s8 q2, #0                     \n\t"
+      "vdup.s8 d6, d0[4]                  \n\t"
+      "vdup.s8 d7, d1[4]                  \n\t"
+      "vmlal.s8 q2, d2, d6                \n\t"
+      "vmlal.s8 q2, d3, d7                \n\t"
+      "vaddw.s16 q12, q12, d4             \n\t"
+      "vaddw.s16 q13, q13, d5             \n\t"  // res row 4
+      "vmov.s8 q2, #0                     \n\t"
+      "vdup.s8 d6, d0[5]                  \n\t"
+      "vdup.s8 d7, d1[5]                  \n\t"
+      "vmlal.s8 q2, d2, d6                \n\t"
+      "vmlal.s8 q2, d3, d7                \n\t"
+      "vaddw.s16 q14, q14, d4             \n\t"
+      "vaddw.s16 q15, q15, d5             \n\t"  // res row 5
+
+      "subs %[kc3], %[kc3], #1            \n\t"
+      "bge 2b                             \n\t"
+
+      "3:                                 \n\t"  // odd, last row
+      "subs %[kc4], %[kc4], #1            \n\t"
+      "blt 4f                             \n\t"
+      "vld1.s8 {d0}, [%[a_ptr]]           \n\t"
+      "vld1.s8 {d1}, [%[b_ptr]]           \n\t"
+      "vdup.s8 d2, d0[0]                  \n\t"
+      "vmull.s8 q2, d1, d2                \n\t"
+      "vaddw.s16 q4, q4, d4               \n\t"
+      "vaddw.s16 q5, q5, d5               \n\t"  // res row 0
+      "vdup.s8 d2, d0[1]                  \n\t"
+      "vmull.s8 q2, d1, d2                \n\t"
+      "vaddw.s16 q6, q6, d4               \n\t"
+      "vaddw.s16 q7, q7, d5               \n\t"  // res row 1
+      "vdup.s8 d2, d0[2]                  \n\t"
+      "vmull.s8 q2, d1, d2                \n\t"
+      "vaddw.s16 q8, q8, d4               \n\t"
+      "vaddw.s16 q9, q9, d5               \n\t"  // res row 2
+      "vdup.s8 d2, d0[3]                  \n\t"
+      "vmull.s8 q2, d1, d2                \n\t"
+      "vaddw.s16 q10, q10, d4             \n\t"
+      "vaddw.s16 q11, q11, d5             \n\t"  // res row 3
+      "vdup.s8 d2, d0[4]                  \n\t"
+      "vmull.s8 q2, d1, d2                \n\t"
+      "vaddw.s16 q12, q12, d4             \n\t"
+      "vaddw.s16 q13, q13, d5             \n\t"  // res row 4
+      "vdup.s8 d2, d0[5]                  \n\t"
+      "vmull.s8 q2, d1, d2                \n\t"
+      "vaddw.s16 q14, q14, d4             \n\t"
+      "vaddw.s16 q15, q15, d5             \n\t"  // res row 5
+      "4:                                 \n\t"
+      "vst1.32 {q4, q5}, [%[c]], %[step]  \n\t"
+      "vst1.32 {q6, q7}, [%[c]], %[step]  \n\t"
+      "vst1.32 {q8, q9}, [%[c]], %[step]  \n\t"
+      "vst1.32 {q10, q11}, [%[c]], %[step] \n\t"
+      "vst1.32 {q12, q13}, [%[c]], %[step] \n\t"
+      "vst1.32 {q14, q15}, [%[c]]         \n\t"
       :
       : [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1),
-        [kc2] "r"(kc2), [step] "r"(step)
+        [kc3] "r"(kc3), [kc4] "r"(kc4), [step] "r"(step)
       : "cc", "memory", "r0", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
         "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
 #endif
diff --git a/test/operators/test_mul_op.cpp b/test/operators/test_mul_op.cpp
index 8ebf092689..3080100e70 100644
--- a/test/operators/test_mul_op.cpp
+++ b/test/operators/test_mul_op.cpp
@@ -12,80 +12,80 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include "../test_helper.h"
 #include "../test_include.h"
 #include "operators/mul_op.h"
 
-int main() {
-  paddle_mobile::Loader<paddle_mobile::CPU> loader;
-  auto program = loader.Load(g_resnet);
-  PADDLE_MOBILE_ENFORCE(program.originProgram != nullptr,
-                        "program file read fail");
-
-  Executor4Test<paddle_mobile::CPU,
-                paddle_mobile::operators::MulOp<paddle_mobile::CPU, float>>
-      executor(program, "mul");
-
-  // 1. input_tensors;
-  vector<Tensor> input_tensors;
-
-  Tensor input1;
-  auto input1_data = CreateInput<float>(&input1, {3, 2, 1, 1}, 0, 1);
-  input_tensors.push_back(input1);
-  Tensor input2;
-  auto input2_data = CreateInput<float>(&input2, {2, 3}, 0, 1);
-  input_tensors.push_back(input2);
-
-  // 2. input_names
-  vector<string> input_names({
-      "pool2d_0.tmp_0",
-      "fc_0.w_0",
-  });
-
-  // 3. output_names
-  vector<string> output_names({"fc_0.tmp_0"});
-
-  // 4. out_dims;
-  vector<DDim> out_ddims;
-  auto out_ddim = paddle_mobile::framework::make_ddim({3, 3});
-  out_ddims.push_back(out_ddim);
-
-  auto output = executor.Predict<LoDTensor>(input_tensors, input_names,
-                                            output_names, out_ddims);
-
-  auto output0_data = output[0]->data<float>();
-
-  auto dim_1 = input1.numel() / input1.dims()[0];
-  DLOG << " input1 : ";
-  for (int i = 0; i < input1.dims()[0]; ++i) {
-    for (int j = 0; j < dim_1; ++j) {
-      DLOGF("%f ", input1_data[i * dim_1 + j]);
+#define a(i, j) a[(i)*lda + (j)]
+#define b(i, j) b[(i)*ldb + (j)]
+#define c(i, j) c[(i)*ldc + (j)]
+
+namespace paddle_mobile {
+using framework::AttributeMap;
+using framework::DDim;
+using framework::Scope;
+using framework::make_ddim;
+template <typename I, typename O>
+int TestMulOP() {
+  int32_t m = 1024;
+  int32_t n = 1024;
+  int32_t k = 1024;
+  int32_t lda = k;
+  int32_t ldb = n;
+  int32_t ldc = n;
+  DDim inputA_shape = make_ddim({m, k});
+  DDim inputB_shape = make_ddim({k, n});
+  VariableNameMap inputs;
+  VariableNameMap outputs;
+  auto scope = std::make_shared<Scope>();
+  inputs["X"] = std::vector<std::string>({"inputA"});
+  inputs["Y"] = std::vector<std::string>({"inputB"});
+  outputs["Out"] = std::vector<std::string>({"output"});
+
+  auto inputA_var = scope.get()->Var("inputA");
+  auto inputA = inputA_var->template GetMutable<framework::LoDTensor>();
+  SetupTensor<I>(inputA, inputA_shape, -127, 127);
+  auto inputB_var = scope.get()->Var("inputB");
+  auto inputB = inputB_var->template GetMutable<framework::LoDTensor>();
+  SetupTensor<I>(inputB, inputB_shape, -127, 127);
+
+  auto output_var = scope.get()->Var("output");
+  AttributeMap attrs;
+  attrs["x_num_col_dims"].Set<int>(1);
+  attrs["y_num_col_dims"].Set<int>(1);
+  auto *op =
+      new operators::MulOp<CPU, float>("mul", inputs, outputs, attrs, scope);
+  op->InferShape();
+  op->Run();
+  auto output = output_var->template Get<framework::LoDTensor>();
+  const O *output_data = output->data<O>();
+  // compare
+  O *c = static_cast<O *>(memory::Alloc(sizeof(O) * m * n));
+  I *a = inputA->data<I>();
+  I *b = inputB->data<I>();
+  for (int32_t i = 0; i < m; ++i) {
+    for (int32_t j = 0; j < n; ++j) {
+      O r = 0;
+      for (int32_t p = 0; p < k; p++) {
+        r += static_cast<O>(a(i, p)) * static_cast<O>(b(p, j));
+      }
+      c(i, j) = r;
     }
-  }
-  DLOGF("\n");
   }
-  auto dim_2 = input2.numel() / input2.dims()[0];
-  DLOG << " input2 : ";
-  for (int i = 0; i < input2.dims()[0]; ++i) {
-    for (int j = 0; j < dim_2; ++j) {
-      DLOGF("%f ", input2_data[i * dim_2 + j]);
-    }
-    DLOGF("\n");
-  }
-
-  auto dim_output0 = output[0]->numel() / output[0]->dims()[0];
-  DLOG << " output : ";
-  for (int i = 0; i < output[0]->dims()[0]; ++i) {
-    for (int j = 0; j < dim_output0; ++j) {
-      DLOGF("%f ", output0_data[i * dim_2 + j]);
-    }
-    DLOGF("\n");
+  for (int32_t i = 0; i < m * n; ++i) {
+    PADDLE_MOBILE_ENFORCE(
+        output_data[i] == c[i], "output[%d] = %d, output_cmp[%d] = %d", i,
+        static_cast<int32_t>(output_data[i]), i, static_cast<int32_t>(c[i]));
   }
+  DLOG << "Run MulOp successfully!";
+  delete op;
+  return 0;
+}
+}  // namespace paddle_mobile
 
-  /// output (3,3)
-  DLOG << "output memory size : " << output[0]->memory_size();
-  DLOG << "output numel : " << output[0]->numel();
-
-  DLOG << input1_data[0] << " x " << input2_data[0] << " + " << input1_data[1]
-       << " x " << input2_data[0 + 3] << " = " << output0_data[0];
+int main() {
+  paddle_mobile::TestMulOP<int8_t, int32_t>();
+  paddle_mobile::TestMulOP<float, float>();
   return 0;
 }
-- 
GitLab