From 5ac1e63c53b6d60d425566953d09f1da2f773454 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Sun, 21 Oct 2018 04:43:17 +0000 Subject: [PATCH] Support 1x1 and 7x7 conv, fix quant scale, we can run the googlenet with int8 now --- src/framework/operator.cpp | 2 +- .../kernel/arm/dequantize_kernel.cpp | 3 +- src/operators/kernel/arm/quantize_kernel.cpp | 11 +- .../kernel/central-arm-func/conv_arm_func.h | 78 +- .../depthwise_conv_arm_func.h | 2 +- .../kernel/central-arm-func/mul_arm_func.h | 16 +- src/operators/math/conv3x3_arm_int8.cpp | 34 +- src/operators/math/conv_arm_int8.h | 4 + src/operators/math/gemm.cpp | 7 +- src/operators/math/gemm.h | 50 ++ src/operators/math/gemm_int8.cpp | 652 ++++++++++++++ src/operators/math/im2col.cpp | 804 +++++++++--------- src/operators/math/math_function.cpp | 8 +- src/operators/math/math_function.h | 3 +- src/operators/math/math_function_int8.cpp | 64 ++ src/operators/math/vol2col.cpp | 61 +- src/operators/op_param.h | 6 +- test/CMakeLists.txt | 4 + test/common/test_gemm_accuracy.cpp | 2 +- test/common/test_gemm_int8_accuracy.cpp | 131 +++ test/common/test_gemm_perf.cpp | 56 +- test/net/test_googlenet.cpp | 18 +- test/operators/test_dequantize_op.cpp | 2 +- test/operators/test_int8_conv_op.cpp | 92 +- test/operators/test_mul_op.cpp | 134 +-- test/operators/test_quantize_op.cpp | 25 +- 26 files changed, 1608 insertions(+), 661 deletions(-) create mode 100644 src/operators/math/gemm_int8.cpp create mode 100644 src/operators/math/math_function_int8.cpp create mode 100644 test/common/test_gemm_int8_accuracy.cpp diff --git a/src/framework/operator.cpp b/src/framework/operator.cpp index dd865fb27d..21b14dfcac 100644 --- a/src/framework/operator.cpp +++ b/src/framework/operator.cpp @@ -32,7 +32,7 @@ template vector OperatorBase::GetInputKeys() const { auto it = op_input_output_key.find(type_); if (it == op_input_output_key.end()) { - DLOG << type_ << " has no outputs"; + DLOG << type_ << " has no inputs"; return {}; } return it->second.first; diff --git a/src/operators/kernel/arm/dequantize_kernel.cpp b/src/operators/kernel/arm/dequantize_kernel.cpp index 3033c16c74..935ce470a8 100644 --- a/src/operators/kernel/arm/dequantize_kernel.cpp +++ b/src/operators/kernel/arm/dequantize_kernel.cpp @@ -38,7 +38,8 @@ void DequantizeKernel::Compute( const int32_t *x = input->data(); float *y = output->mutable_data(); size_t size = output->numel(); - float scale = 1.f / (activation_scale * weight_scale); + // float scale = 1.f / (activation_scale * weight_scale); + float scale = activation_scale / weight_scale; #if defined(__ARM_NEON__) || defined(__ARM_NEON) size_t loop = size >> 4; size_t remain = size & 0xF; diff --git a/src/operators/kernel/arm/quantize_kernel.cpp b/src/operators/kernel/arm/quantize_kernel.cpp index e2c8efc299..fe8256a1ea 100644 --- a/src/operators/kernel/arm/quantize_kernel.cpp +++ b/src/operators/kernel/arm/quantize_kernel.cpp @@ -280,17 +280,18 @@ void QuantizeKernel::Compute( } max_abs = std::max(max_abs, 1e-6f); // only support int8 currently - float online_scale = 127 / max_abs; - param.online_scale_->mutable_data()[0] = online_scale; + float scale = 127 / max_abs; + param.online_scale_->mutable_data()[0] = max_abs; switch (param.round_type_) { case ROUND_NEAREST_TO_EVEN: - quantize_round_to_even(input, online_scale, output); + quantize_round_to_even(input, scale, output); break; case ROUND_NEAREST_TOWARDS_ZERO: - quantize_round_to_zero(input, online_scale, output); + quantize_round_to_zero(input, scale, output); break; case ROUND_NEAREST_AWAY_ZERO: 
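+      // Illustrative example (assumed values): with max_abs = 6.35 the
+      // quantization scale is 127 / 6.35 = 20, so an activation of 0.125f
+      // maps to 0.125 * 20 = 2.5; round-to-even stores 2, round-away-from-
+      // zero stores 3, round-towards-zero stores 2. Note that online_scale_
+      // now records max_abs itself rather than 127 / max_abs, which is
+      // presumably the value the dequantize kernel reads as activation_scale.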
- quantize_round_to_nearest(input, online_scale, output); + quantize_round_to_nearest(input, scale, output); + break; default: LOG(kLOG_ERROR) << "round type is not supported."; break; diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.h b/src/operators/kernel/central-arm-func/conv_arm_func.h index a2d887aa7a..f80a8f9441 100644 --- a/src/operators/kernel/central-arm-func/conv_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_arm_func.h @@ -28,15 +28,15 @@ limitations under the License. */ namespace paddle_mobile { namespace operators { +template inline void ConvBasic(const ConvParam ¶m) { const Tensor *input = param.Input(); Tensor filter = *param.Filter(); Tensor *output = param.Output(); - output->mutable_data(); int groups = param.Groups(); - std::vector strides = param.Strides(); - std::vector paddings = param.Paddings(); - std::vector dilations = param.Dilations(); + const std::vector strides = param.Strides(); + const std::vector paddings = param.Paddings(); + const std::vector dilations = param.Dilations(); const int batch_size = static_cast(input->dims()[0]); @@ -60,7 +60,7 @@ inline void ConvBasic(const ConvParam ¶m) { Tensor col; Tensor col_matrix; if (is_expand) { - col.mutable_data(col_shape); + col.mutable_data(col_shape); col_matrix.ShareDataWith(col); col_matrix.Resize(col_matrix_shape); } @@ -79,8 +79,8 @@ inline void ConvBasic(const ConvParam ¶m) { int in_step = static_cast(input->dims()[1]) / groups; int out_step = static_cast(output->dims()[1]) / groups; - math::Vol2ColFunctor vol2col; - math::Im2ColFunctor im2col; + math::Vol2ColFunctor vol2col; + math::Im2ColFunctor im2col; for (int i = 0; i < batch_size; i++) { Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); @@ -99,6 +99,7 @@ inline void ConvBasic(const ConvParam ¶m) { std::vector{paddings[0], paddings[1], paddings[0], paddings[1]}, &col); + } else if (data_dim == 3U) { // vol2col vol2col(in_slice, dilations, strides, paddings, &col); @@ -107,7 +108,8 @@ inline void ConvBasic(const ConvParam ¶m) { // gemm Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(filter_slice, false, col_matrix, false, + + math::matmul(filter_slice, false, col_matrix, false, static_cast(1), &out_slice, static_cast(0)); } @@ -126,42 +128,41 @@ inline void ConvCompute_int8(const ConvParam ¶m) { const Tensor *input = param.Input(); Tensor *filter = param.Filter(); Tensor *output = param.Output(); - output->mutable_data(); int groups = param.Groups(); - std::vector strides = param.Strides(); - std::vector paddings = param.Paddings(); - std::vector dilations = param.Dilations(); + const std::vector &strides = param.Strides(); + const std::vector &paddings = param.Paddings(); + const std::vector &dilations = param.Dilations(); int kernel_h = filter->dims()[2]; int kernel_w = filter->dims()[3]; - const int batch_size = static_cast(input->dims()[0]); - math::PadFunctor pad; - - Tensor input_pad; - for (int i = 0; i < batch_size; ++i) { - Tensor in_batch = input->Slice(i, i + 1); - Tensor out_batch = output->Slice(i, i + 1); - if (paddings[0] == 0 && paddings[1] == 0) { - input_pad = in_batch; - } else { - framework::DDim pad_shape = in_batch.dims(); - pad_shape[2] += 2 * paddings[0]; - pad_shape[3] += 2 * paddings[1]; - input_pad.mutable_data(pad_shape); - pad(in_batch, paddings[0], paddings[1], &input_pad); - } + output->mutable_data(); - if (strides[1] == strides[0] && strides[1] < 6 && kernel_h == kernel_w 
&& - kernel_h < 8 && groups == 1 && dilations[0] == dilations[1] && - dilations[1] == 1) { - ConvFunc conv_func = conv_funcs_table[kernel_h - 1][strides[0] - 1]; - if (conv_func) { - conv_func(input_pad, *filter, &out_batch); + ConvFunc conv_func = 0; + if (strides[1] == strides[0] && strides[1] < 6 && kernel_h == kernel_w && + kernel_h < 8 && groups == 1 && dilations[0] == dilations[1] && + dilations[1] == 1) { + conv_func = conv_funcs_table[kernel_h - 1][strides[0] - 1]; + } + if (conv_func) { + int batch_size = input->dims()[0]; + math::PadFunctor pad; + + Tensor input_pad; + for (int i = 0; i < batch_size; ++i) { + Tensor in_batch = input->Slice(i, i + 1); + Tensor out_batch = output->Slice(i, i + 1); + if (paddings[0] == 0 && paddings[1] == 0) { + input_pad = in_batch; } else { - // TODO(hjchen2) + framework::DDim pad_shape = in_batch.dims(); + pad_shape[2] += 2 * paddings[0]; + pad_shape[3] += 2 * paddings[1]; + input_pad.mutable_data(pad_shape); + pad(in_batch, paddings[0], paddings[1], &input_pad); } - } else { - // TODO(hjchen2) + conv_func(input_pad, *filter, &out_batch); } + } else { + ConvBasic(param); } } @@ -170,6 +171,7 @@ void ConvCompute(const ConvParam ¶m) { if (param.Input()->type() == typeid(int8_t)) { ConvCompute_int8(param); } else { + param.Output()->mutable_data(); if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && @@ -183,7 +185,7 @@ void ConvCompute(const ConvParam ¶m) { math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(), param.Filter(), nullptr, param.Output(), false); } else { - ConvBasic(param); + ConvBasic(param); } } } diff --git a/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h b/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h index 2a1afb3cf6..ff5d5d4b2a 100644 --- a/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h +++ b/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h @@ -44,7 +44,7 @@ void DepthwiseConvCompute(const ConvParam ¶m) { Bias, false); } else { - ConvBasic(param); + ConvBasic(param); } } diff --git a/src/operators/kernel/central-arm-func/mul_arm_func.h b/src/operators/kernel/central-arm-func/mul_arm_func.h index dd6df54da5..07e634e3be 100644 --- a/src/operators/kernel/central-arm-func/mul_arm_func.h +++ b/src/operators/kernel/central-arm-func/mul_arm_func.h @@ -58,7 +58,7 @@ void MulCompute(const MulParam ¶m) { const Tensor *input_x = param.InputX(); const Tensor *input_y = param.InputY(); Tensor *out = param.Out(); - out->mutable_data(); + const Tensor x_matrix = input_x->dims().size() > 2 ? 
framework::ReshapeToMatrix(*input_x, param.XNumColDims()) @@ -71,15 +71,21 @@ void MulCompute(const MulParam ¶m) { if (out_dim.size() != 2) { out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); } - math::matmul(x_matrix, false, y_matrix, false, static_cast(1), - out, static_cast(0)); + if (param.InputX()->type() == typeid(int8_t)) { + out->mutable_data(); + math::matmul(x_matrix, false, y_matrix, false, + static_cast(1), out, static_cast(0)); + + } else { + out->mutable_data(); + math::matmul(x_matrix, false, y_matrix, false, static_cast(1), + out, static_cast(0)); + } if (out_dim.size() != 2) { out->Resize(out_dim); } } -template class MulKernel; - } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/math/conv3x3_arm_int8.cpp b/src/operators/math/conv3x3_arm_int8.cpp index a61406aa65..f8a4e9f409 100644 --- a/src/operators/math/conv3x3_arm_int8.cpp +++ b/src/operators/math/conv3x3_arm_int8.cpp @@ -112,15 +112,15 @@ void conv3x3s1_int8(const framework::Tensor& input, "vmull.s8 q7, d4, d7 \n" "vmlal.s8 q6, d5, d8 \n" "vaddw.s16 q12, q12, d12 \n" - "vaddw.s16 q12, q12, d14 \n" "vaddw.s16 q13, q13, d13 \n" + "vaddw.s16 q12, q12, d14 \n" "vaddw.s16 q13, q13, d15 \n" "vmull.s8 q6, d2, d9 \n" "vmull.s8 q7, d4, d10 \n" "vmlal.s8 q6, d5, d11 \n" "vaddw.s16 q14, q14, d12 \n" - "vaddw.s16 q14, q14, d14 \n" "vaddw.s16 q15, q15, d13 \n" + "vaddw.s16 q14, q14, d14 \n" "vaddw.s16 q15, q15, d15 \n" "vld1.8 {d2-d3}, [%[r2]] \n" // r2 @@ -139,8 +139,8 @@ void conv3x3s1_int8(const framework::Tensor& input, "vmull.s8 q7, d4, d10 \n" "vmlal.s8 q6, d5, d11 \n" "vaddw.s16 q10, q10, d12 \n" - "vaddw.s16 q10, q10, d14 \n" "vaddw.s16 q11, q11, d13 \n" + "vaddw.s16 q10, q10, d14 \n" "vaddw.s16 q11, q11, d15 \n" "vdup.s8 d6, d0[6] \n" @@ -153,21 +153,23 @@ void conv3x3s1_int8(const framework::Tensor& input, "vmull.s8 q7, d4, d7 \n" "vmlal.s8 q6, d5, d8 \n" "vaddw.s16 q12, q12, d12 \n" - "vaddw.s16 q12, q12, d14 \n" "vaddw.s16 q13, q13, d13 \n" + "vaddw.s16 q12, q12, d14 \n" "vaddw.s16 q13, q13, d15 \n" + + "vld1.32 {d12-d15}, [%[output0]] \n" + "vadd.s32 q6, q6, q12 \n" + "vadd.s32 q7, q7, q13 \n" + "vst1.32 {d12-d15}, [%[output0]]! \n" + "vmull.s8 q6, d2, d9 \n" "vmull.s8 q7, d4, d10 \n" "vmlal.s8 q6, d5, d11 \n" "vaddw.s16 q14, q14, d12 \n" - "vaddw.s16 q14, q14, d14 \n" "vaddw.s16 q15, q15, d13 \n" + "vaddw.s16 q14, q14, d14 \n" "vaddw.s16 q15, q15, d15 \n" - "vld1.32 {d12-d15}, [%[output0]] \n" - "vadd.s32 q6, q6, q12 \n" - "vadd.s32 q7, q7, q13 \n" - "vst1.32 {d12-d15}, [%[output0]]! \n" "vld1.32 {d12-d15}, [%[output1]] \n" "vadd.s32 q6, q6, q14 \n" "vadd.s32 q7, q7, q15 \n" @@ -182,21 +184,23 @@ void conv3x3s1_int8(const framework::Tensor& input, "vmull.s8 q7, d4, d7 \n" "vmlal.s8 q6, d5, d8 \n" "vaddw.s16 q8, q8, d12 \n" + "vaddw.s16 q9, q9, d15 \n" "vaddw.s16 q8, q8, d14 \n" "vaddw.s16 q9, q9, d13 \n" - "vaddw.s16 q9, q9, d15 \n" + + "vld1.32 {d12-d15}, [%[output0n]] \n" + "vadd.s32 q6, q6, q8 \n" + "vadd.s32 q7, q7, q9 \n" + "vst1.32 {d12-d15}, [%[output0n]]! \n" + "vmull.s8 q6, d2, d9 \n" "vmull.s8 q7, d4, d10 \n" "vmlal.s8 q6, d5, d11 \n" "vaddw.s16 q10, q10, d12 \n" + "vaddw.s16 q11, q11, d15 \n" "vaddw.s16 q10, q10, d14 \n" "vaddw.s16 q11, q11, d13 \n" - "vaddw.s16 q11, q11, d15 \n" - "vld1.32 {d12-d15}, [%[output0n]] \n" - "vadd.s32 q6, q6, q8 \n" - "vadd.s32 q7, q7, q9 \n" - "vst1.32 {d12-d15}, [%[output0n]]! 
\n" "vld1.32 {d12-d15}, [%[output1n]] \n" "vadd.s32 q6, q6, q10 \n" "vadd.s32 q7, q7, q11 \n" diff --git a/src/operators/math/conv_arm_int8.h b/src/operators/math/conv_arm_int8.h index 4e59307158..98843e6158 100644 --- a/src/operators/math/conv_arm_int8.h +++ b/src/operators/math/conv_arm_int8.h @@ -24,6 +24,10 @@ namespace operators { void conv3x3s1_int8(const framework::Tensor& input, const framework::Tensor& weight, framework::Tensor* output); +void conv3x3s1_int8_4c(const framework::Tensor& input, + const framework::Tensor& weight, + framework::Tensor* output); + void conv5x5s1_int8(const framework::Tensor& input, const framework::Tensor& weight, framework::Tensor* output); diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp index 1fcfc5f98a..2990f7a0f8 100644 --- a/src/operators/math/gemm.cpp +++ b/src/operators/math/gemm.cpp @@ -3662,7 +3662,7 @@ void Gemm::AddDot6x8(int k, const float *a, const float *b, float *c, int ldc) { b_ptr = b; int kc1 = k / 8; int kc2 = k % 8; - int step = 4 * ldc; + int step = sizeof(float) * ldc; asm volatile( "pld [%[a_ptr]] \n\t" "pld [%[a_ptr], #64] \n\t" @@ -3866,11 +3866,10 @@ void Gemm::AddDot6x8(int k, const float *a, const float *b, float *c, int ldc) { : : [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1), [kc2] "r"(kc2), [step] "r"(step) - : "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", - "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); + : "cc", "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", + "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); #endif // __aarch64__ -#else #endif // __ARM_NEON } diff --git a/src/operators/math/gemm.h b/src/operators/math/gemm.h index d7f5b2249a..b937173dd3 100644 --- a/src/operators/math/gemm.h +++ b/src/operators/math/gemm.h @@ -96,6 +96,7 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, void InnerKernelWithPRelu(int mc, int nc, const float *a, const float *b, float *c, float *C, int ldc, float *p, std::string mode, float *bias, float *bias1); + /* // 向量矩阵乘法 (M = 1) void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda, @@ -139,6 +140,7 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, float *new_scale, float *new_bias); void WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc, float *new_scale, float *new_bias, float *bias1); + /* // 向量矩阵乘法结果回写 // C = A * B @@ -185,15 +187,63 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, const float *B, int ldb, float *C, int ldc, float *p, std::string mode, float *bias, float *bias1); + /************************ 8 bit function cluster ************************/ + // 8 bit int small block inner product + void AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, + int32_t ldc); + + // 8 bit int inner product + void InnerKernelWithBias(int32_t mc, int32_t nc, int8_t alpha, + const int8_t *a, const int8_t *b, int8_t beta, + int32_t *c, int32_t *C, int32_t ldc, bool relu, + int8_t *bias); + + // 8 bit int pack function + void PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, + int32_t lda, int8_t *buffer); + void PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, + int32_t ldb, int8_t *buffer); + + // 8 bit int matrix product + void Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A, + int32_t lda, const int8_t *B, int32_t ldb, int8_t beta, int32_t *C, + int32_t ldc, bool relu, int8_t *bias); + + // 8 bit int write back + 
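+  // The write-back helpers below copy the packed int32 accumulators produced
+  // by AddDot6x8 into the caller's C; in this change only WriteBasic has a
+  // real body, the remaining variants are stubs. A minimal usage sketch of
+  // the int8 path against the Sgemm declaration above, assuming row-major
+  // matrices with lda = K and ldb = ldc = N (hypothetical sizes M, N, K):
+  //
+  //   Gemm gemm;
+  //   gemm.Sgemm(M, N, K, /*alpha=*/1, A, K, B, N, /*beta=*/0, C, N,
+  //              /*relu=*/false, /*bias=*/nullptr);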
// C = alpha * A * B + beta * C + void WriteWithAlphaBeta(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc); + // C = A * B + void WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, int32_t ldc); + // C = A * B + C + void WriteWithAdd(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc); + // C = A * B + bias + void WriteWithAddV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc, int8_t *bias); + // C = A * B + C, relu(C) + void WriteWithAddRelu(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc); + // C = A * B + bias, relu(C) + void WriteWithAddReluV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc, int8_t *bias); + private: int MC = 0; int KC = 0; int NC = 0; + // 32位 float float *packedA; float *packedB; float *packedC; float *zero; + + // 8 bit int + int8_t *packedA_int8; + int8_t *packedB_int8; + int32_t *packedC_int8; + int8_t *zero_int8; }; } // namespace math diff --git a/src/operators/math/gemm_int8.cpp b/src/operators/math/gemm_int8.cpp new file mode 100644 index 0000000000..c52cd2fb29 --- /dev/null +++ b/src/operators/math/gemm_int8.cpp @@ -0,0 +1,652 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "common/log.h" +#include "memory/t_malloc.h" +#include "operators/math/gemm.h" +#if __ARM_NEON +#include +#endif +#ifdef _OPENMP +#include +#endif + +namespace paddle_mobile { +namespace operators { +namespace math { + +// 8 bit int small block inner product +void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, + int32_t ldc) { +#if __ARM_NEON + const int8_t *a_ptr, *b_ptr; + a_ptr = a; + b_ptr = b; + int32_t kc1 = k >> 3; + int32_t kc2 = k & 7; + int32_t kc3 = kc2 >> 1; + int32_t kc4 = kc2 & 1; + int32_t step = sizeof(int32_t) * ldc; + asm volatile( + // q4-q15: save 48 results + "pld [%[a_ptr]] \n\t" + "pld [%[b_ptr]] \n\t" + "vmov.s8 q4, #0 \n\t" + "vmov.s8 q5, #0 \n\t" + "vmov.s8 q6, #0 \n\t" + "vmov.s8 q7, #0 \n\t" + "vmov.s8 q8, #0 \n\t" + "vmov.s8 q9, #0 \n\t" + "vmov.s8 q10, #0 \n\t" + "vmov.s8 q11, #0 \n\t" + "vmov.s8 q12, #0 \n\t" + "vmov.s8 q13, #0 \n\t" + "vmov.s8 q14, #0 \n\t" + "vmov.s8 q15, #0 \n\t" + "mov r0, #6 \n\t" + "subs %[kc1], %[kc1], #1 \n\t" + "blt 1f \n\t" + "0: \n\t" + "vld1.s8 {d0}, [%[a_ptr]], r0 \n\t" // A col0 + "vld1.s8 {d1}, [%[a_ptr]], r0 \n\t" // A col1, q0 + // used + "vld1.s8 {d2-d3}, [%[b_ptr]]! 
\n\t" // B row0, B + // row1, q1 + // used + "vmov.s8 q2, #0 \n\t" // q2 used + "vdup.s8 d6, d0[0] \n\t" + "vdup.s8 d7, d1[0] \n\t" // q3 used + "vmlal.s8 q2, d2, d6 \n\t" // A col00 * B + // row0 + "vmlal.s8 q2, d3, d7 \n\t" // A col10 * B + // row1, q3 + // free + "vaddw.s16 q4, q4, d4 \n\t" + "vaddw.s16 q5, q5, d5 \n\t" // res row 0 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d6, d0[1] \n\t" + "vdup.s8 d7, d1[1] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q6, q6, d4 \n\t" + "vaddw.s16 q7, q7, d5 \n\t" // res row 1 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d6, d0[2] \n\t" + "vdup.s8 d7, d1[2] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q8, q8, d4 \n\t" + "vaddw.s16 q9, q9, d5 \n\t" // res row 2 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d6, d0[3] \n\t" + "vdup.s8 d7, d1[3] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q10, q10, d4 \n\t" + "vaddw.s16 q11, q11, d5 \n\t" // res row 3 + "vmov.s8 q2, #0. \n\t" + "vdup.s8 d6, d0[4] \n\t" + "vdup.s8 d7, d1[4] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q12, q12, d4 \n\t" + "vaddw.s16 q13, q13, d5 \n\t" // res row 4 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d6, d0[5] \n\t" + "vdup.s8 d7, d1[5] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q14, q14, d4 \n\t" + "vaddw.s16 q15, q15, d5 \n\t" // res row 5 + + "vld1.s8 {d0}, [%[a_ptr]], r0 \n\t" // A col0 + "vld1.s8 {d1}, [%[a_ptr]], r0 \n\t" // A col1, q0 + // used + "vld1.s8 {d2-d3}, [%[b_ptr]]! \n\t" // B row0, B + // row1, q1 + // used + "vmov.s8 q2, #0 \n\t" // q2 used + "vdup.s8 d6, d0[0] \n\t" + "vdup.s8 d7, d1[0] \n\t" // q3 used + "vmlal.s8 q2, d2, d6 \n\t" // A col00 * B + // row0 + "vmlal.s8 q2, d3, d7 \n\t" // A col10 * B + // row1, q3 + // free + "vaddw.s16 q4, q4, d4 \n\t" + "vaddw.s16 q5, q5, d5 \n\t" // res row 0 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d6, d0[1] \n\t" + "vdup.s8 d7, d1[1] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q6, q6, d4 \n\t" + "vaddw.s16 q7, q7, d5 \n\t" // res row 1 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d6, d0[2] \n\t" + "vdup.s8 d7, d1[2] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q8, q8, d4 \n\t" + "vaddw.s16 q9, q9, d5 \n\t" // res row 2 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d6, d0[3] \n\t" + "vdup.s8 d7, d1[3] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q10, q10, d4 \n\t" + "vaddw.s16 q11, q11, d5 \n\t" // res row 3 + "vmov.s8 q2, #0. \n\t" + "vdup.s8 d6, d0[4] \n\t" + "vdup.s8 d7, d1[4] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q12, q12, d4 \n\t" + "vaddw.s16 q13, q13, d5 \n\t" // res row 4 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d6, d0[5] \n\t" + "vdup.s8 d7, d1[5] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q14, q14, d4 \n\t" + "vaddw.s16 q15, q15, d5 \n\t" // res row 5 + + "vld1.s8 {d0}, [%[a_ptr]], r0 \n\t" // A col0 + "vld1.s8 {d1}, [%[a_ptr]], r0 \n\t" // A col1, q0 + // used + "vld1.s8 {d2-d3}, [%[b_ptr]]! 
\n\t" // B row0, B + // row1, q1 + // used + "vmov.s8 q2, #0 \n\t" // q2 used + "vdup.s8 d6, d0[0] \n\t" + "vdup.s8 d7, d1[0] \n\t" // q3 used + "vmlal.s8 q2, d2, d6 \n\t" // A col00 * B + // row0 + "vmlal.s8 q2, d3, d7 \n\t" // A col10 * B + // row1, q3 + // free + "vaddw.s16 q4, q4, d4 \n\t" + "vaddw.s16 q5, q5, d5 \n\t" // res row 0 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d6, d0[1] \n\t" + "vdup.s8 d7, d1[1] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q6, q6, d4 \n\t" + "vaddw.s16 q7, q7, d5 \n\t" // res row 1 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d6, d0[2] \n\t" + "vdup.s8 d7, d1[2] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q8, q8, d4 \n\t" + "vaddw.s16 q9, q9, d5 \n\t" // res row 2 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d6, d0[3] \n\t" + "vdup.s8 d7, d1[3] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q10, q10, d4 \n\t" + "vaddw.s16 q11, q11, d5 \n\t" // res row 3 + "vmov.s8 q2, #0. \n\t" + "vdup.s8 d6, d0[4] \n\t" + "vdup.s8 d7, d1[4] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q12, q12, d4 \n\t" + "vaddw.s16 q13, q13, d5 \n\t" // res row 4 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d6, d0[5] \n\t" + "vdup.s8 d7, d1[5] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q14, q14, d4 \n\t" + "vaddw.s16 q15, q15, d5 \n\t" // res row 5 + + "vld1.s8 {d0}, [%[a_ptr]], r0 \n\t" // A col0 + "vld1.s8 {d1}, [%[a_ptr]], r0 \n\t" // A col1, q0 + // used + "vld1.s8 {d2-d3}, [%[b_ptr]]! \n\t" // B row0, B + // row1, q1 + // used + "vmov.s8 q2, #0 \n\t" // q2 used + "vdup.s8 d6, d0[0] \n\t" + "vdup.s8 d7, d1[0] \n\t" // q3 used + "vmlal.s8 q2, d2, d6 \n\t" // A col00 * B + // row0 + "vmlal.s8 q2, d3, d7 \n\t" // A col10 * B + // row1, q3 + // free + "vaddw.s16 q4, q4, d4 \n\t" + "vaddw.s16 q5, q5, d5 \n\t" // res row 0 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d6, d0[1] \n\t" + "vdup.s8 d7, d1[1] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q6, q6, d4 \n\t" + "vaddw.s16 q7, q7, d5 \n\t" // res row 1 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d6, d0[2] \n\t" + "vdup.s8 d7, d1[2] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q8, q8, d4 \n\t" + "vaddw.s16 q9, q9, d5 \n\t" // res row 2 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d6, d0[3] \n\t" + "vdup.s8 d7, d1[3] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q10, q10, d4 \n\t" + "vaddw.s16 q11, q11, d5 \n\t" // res row 3 + "vmov.s8 q2, #0. \n\t" + "vdup.s8 d6, d0[4] \n\t" + "vdup.s8 d7, d1[4] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q12, q12, d4 \n\t" + "vaddw.s16 q13, q13, d5 \n\t" // res row 4 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d6, d0[5] \n\t" + "vdup.s8 d7, d1[5] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q14, q14, d4 \n\t" + "vaddw.s16 q15, q15, d5 \n\t" // res row 5 + + "subs %[kc1], %[kc1], #1 \n\t" // last <8 rows + "bge 0b \n\t" + "1: \n\t" + "subs %[kc3], %[kc3], #1 \n\t" + "blt 3f \n\t" + "2: \n\t" + "vld1.s8 {d0}, [%[a_ptr]], r0 \n\t" // A col0 + "vld1.s8 {d1}, [%[a_ptr]], r0 \n\t" // A col1, q0 + // used + "vld1.s8 {d2-d3}, [%[b_ptr]]! 
\n\t" // B row0, B + // row1, q1 + // used + "vmov.s8 q2, #0 \n\t" // q2 used + "vdup.s8 d6, d0[0] \n\t" + "vdup.s8 d7, d1[0] \n\t" // q3 used + "vmlal.s8 q2, d2, d6 \n\t" // A col00 * B + // row0 + "vmlal.s8 q2, d3, d7 \n\t" // A col10 * B + // row1, q3 + // free + "vaddw.s16 q4, q4, d4 \n\t" + "vaddw.s16 q5, q5, d5 \n\t" // res row 0 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d6, d0[1] \n\t" + "vdup.s8 d7, d1[1] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q6, q6, d4 \n\t" + "vaddw.s16 q7, q7, d5 \n\t" // res row 1 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d6, d0[2] \n\t" + "vdup.s8 d7, d1[2] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q8, q8, d4 \n\t" + "vaddw.s16 q9, q9, d5 \n\t" // res row 2 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d6, d0[3] \n\t" + "vdup.s8 d7, d1[3] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q10, q10, d4 \n\t" + "vaddw.s16 q11, q11, d5 \n\t" // res row 3 + "vmov.s8 q2, #0. \n\t" + "vdup.s8 d6, d0[4] \n\t" + "vdup.s8 d7, d1[4] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q12, q12, d4 \n\t" + "vaddw.s16 q13, q13, d5 \n\t" // res row 4 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d6, d0[5] \n\t" + "vdup.s8 d7, d1[5] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q14, q14, d4 \n\t" + "vaddw.s16 q15, q15, d5 \n\t" // res row 5 + + "subs %[kc3], %[kc3], #1 \n\t" + "bge 2b \n\t" + + "3: \n\t" // odd, last + // row + "subs %[kc4], %[kc4], #1 \n\t" + "blt 4f \n\t" + "vld1.s8 {d0}, [%[a_ptr]] \n\t" + "vld1.s8 {d1}, [%[b_ptr]] \n\t" + "vdup.s8 d2, d0[0] \n\t" + "vmull.s8 q2, d1, d2 \n\t" + "vaddw.s16 q4, q4, d4 \n\t" + "vaddw.s16 q5, q5, d5 \n\t" // res row 0 + "vdup.s8 d2, d0[1] \n\t" + "vmull.s8 q2, d1, d2 \n\t" + "vaddw.s16 q6, q6, d4 \n\t" + "vaddw.s16 q7, q7, d5 \n\t" // res row 1 + "vdup.s8 d2, d0[2] \n\t" + "vmull.s8 q2, d1, d2 \n\t" + "vaddw.s16 q8, q8, d4 \n\t" + "vaddw.s16 q9, q9, d5 \n\t" // res row 2 + "vdup.s8 d2, d0[3] \n\t" + "vmull.s8 q2, d1, d2 \n\t" + "vaddw.s16 q10, q10, d4 \n\t" + "vaddw.s16 q11, q11, d5 \n\t" // res row 3 + "vdup.s8 d2, d0[4] \n\t" + "vmull.s8 q2, d1, d2 \n\t" + "vaddw.s16 q12, q12, d4 \n\t" + "vaddw.s16 q13, q13, d5 \n\t" // res row 4 + "vdup.s8 d2, d0[5] \n\t" + "vmull.s8 q2, d1, d2 \n\t" + "vaddw.s16 q14, q14, d4 \n\t" + "vaddw.s16 q15, q15, d5 \n\t" // res row 4 + "4: \n\t" + "vst1.32 {q4, q5}, [%[c]], %[step] \n\t" + "vst1.32 {q6, q7}, [%[c]], %[step] \n\t" + "vst1.32 {q8, q9}, [%[c]], %[step] \n\t" + "vst1.32 {q10, q11}, [%[c]], %[step] \n\t" + "vst1.32 {q12, q13}, [%[c]], %[step] \n\t" + "vst1.32 {q14, q15}, [%[c]] \n\t" + : + : [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1), + [kc3] "r"(kc3), [kc4] "r"(kc4), [step] "r"(step) + : "cc", "memory", "r0", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +#endif +} + +// 8 bit int inner product +void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, int8_t alpha, + const int8_t *a, const int8_t *b, int8_t beta, + int32_t *c, int32_t *C, int32_t ldc, bool relu, + int8_t *bias) { +#pragma omp parallel for + for (int32_t j = 0; j < nc; j += NR) { + for (int32_t i = 0; i < mc; i += MR) { + AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); + } + } + if (alpha != 1) { + WriteWithAlphaBeta(mc, nc, c, C, ldc); + return; + } + if (beta == 0) { + WriteBasic(mc, nc, c, C, ldc); + return; + } + if (beta == 1 && !relu) { + if (bias == nullptr) { + WriteWithAdd(mc, nc, c, C, ldc); + } else 
{ + WriteWithAddV1(mc, nc, c, C, ldc, bias); + } + return; + } + if (beta == 1 && relu) { + if (bias == nullptr) { + WriteWithAddRelu(mc, nc, c, C, ldc); + } else { + WriteWithAddReluV1(mc, nc, c, C, ldc, bias); + } + return; + } +} + +// 8 bit int PackMatrixA +void Gemm::PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, + int32_t lda, int8_t *buffer) { + const int32_t i_length = m - m_tail; + for (int32_t i = 0; i < i_length; i += MR) { + const int8_t *a0 = A + i * lda; + const int8_t *a1 = A + (i + 1) * lda; + const int8_t *a2 = A + (i + 2) * lda; + const int8_t *a3 = A + (i + 3) * lda; + const int8_t *a4 = A + (i + 4) * lda; + const int8_t *a5 = A + (i + 5) * lda; + int8_t *local_buffer = buffer + i * k; + for (int32_t j = 0; j < k; ++j) { + *local_buffer++ = *a0++; + *local_buffer++ = *a1++; + *local_buffer++ = *a2++; + *local_buffer++ = *a3++; + *local_buffer++ = *a4++; + *local_buffer++ = *a5++; + } + } + if (m_tail != 0) { + const int8_t *a0 = &A(i_length, 0); + const int8_t *a1 = a0 + lda; + const int8_t *a2 = a0 + 2 * lda; + const int8_t *a3 = a0 + 3 * lda; + const int8_t *a4 = a0 + 4 * lda; + const int8_t *a5 = a0 + 5 * lda; + int8_t *local_buffer = buffer + i_length * k; + switch (m_tail) { + case 1: + a1 = zero_int8; + case 2: + a2 = zero_int8; + case 3: + a3 = zero_int8; + case 4: + a4 = zero_int8; + case 5: + a5 = zero_int8; + break; + default: + break; + } + for (int32_t j = 0; j < k; ++j) { + *local_buffer++ = *a0++; + *local_buffer++ = *a1++; + *local_buffer++ = *a2++; + *local_buffer++ = *a3++; + *local_buffer++ = *a4++; + *local_buffer++ = *a5++; + } + } +} + +// 8 bit int PackMatrixB +void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, + int32_t ldb, int8_t *buffer) { + const int32_t j_length = n - n_tail; + for (int32_t j = 0; j < j_length; j += NR) { + int8_t *local_buffer = buffer + j * k; + for (int32_t i = 0; i < k; ++i) { + const int8_t *b0 = &B(i, j); +#if __ARM_NEON + asm volatile( + // "pld [%[b0]] \n\t" + "vld1.s8 {d0}, [%[b0]] \n\t" + "vst1.s8 {d0}, [%[local_buffer]]! 
\n\t" + : [local_buffer] "+r"(local_buffer) + : [b0] "r"(b0) + : "memory", "q0"); +#else + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; +#endif // __ARM_NEON + } + } + if (n_tail != 0) { + int8_t *local_buffer = buffer + j_length * k; + for (int32_t i = 0; i < k; ++i) { + const int8_t *b0 = &B(i, j_length); + for (int32_t j = j_length; j < n; ++j) { + *local_buffer++ = *b0++; + } + for (int32_t j = n; j < j_length + NR; ++j) { + *local_buffer++ = 0; + } + } + } +} + +// 8 bit int matrix product (m*k x k*n) +void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A, + int32_t lda, const int8_t *B, int32_t ldb, int8_t beta, + int32_t *C, int32_t ldc, bool relu, int8_t *bias) { + // L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73) + // L2 cache is 0.5~4 Mib (Contex-A72 cluster) + int32_t L1 = 32 * 1024; + int32_t L2 = 512 * 1024; + + KC = k; + MC = L1 / (KC * sizeof(int8_t)); + NC = L2 / (KC * sizeof(int8_t)); + + // make sure MC is multiple of MR, and NC is multiple of NR + if (MC == 0) { + MC = MR; + } else { + int32_t mblock_num = (m + MC - 1) / MC; + MC = (m + mblock_num - 1) / mblock_num; + MC = (MC + MR - 1) / MR * MR; + } + // DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n"; + if (NC == 0) { + NC = NR; + } else { + int32_t nblock_num = (n + NC - 1) / NC; + NC = (n + nblock_num - 1) / nblock_num; + NC = (NC + NR - 1) / NR * NR; + } + // DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n"; + packedA_int8 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC)); + packedB_int8 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC)); + packedC_int8 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC)); + zero_int8 = + static_cast(paddle_mobile::memory::Alloc(sizeof(int8_t) * KC)); + + memset(static_cast(zero_int8), 0, sizeof(int8_t) * KC); + int32_t mc, nc; + for (int32_t j = 0; j < n; j += NC) { + nc = s_min(n - j, NC); + PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB_int8); + for (int32_t i = 0; i < m; i += MC) { + mc = s_min(m - i, MC); + PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA_int8); + if (bias == nullptr) { + InnerKernelWithBias(mc, nc, alpha, packedA_int8, packedB_int8, beta, + packedC_int8, &C(i, j), ldc, relu, nullptr); + } else { + InnerKernelWithBias(mc, nc, alpha, packedA_int8, packedB_int8, beta, + packedC_int8, &C(i, j), ldc, relu, bias + i); + } + } + } + + paddle_mobile::memory::Free(packedA_int8); + paddle_mobile::memory::Free(packedB_int8); + paddle_mobile::memory::Free(packedC_int8); + paddle_mobile::memory::Free(zero_int8); +} + +// 8 bit int write back +// C = alpha * A * B + beta * C +void Gemm::WriteWithAlphaBeta(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc) {} +// C = A * B, 8位 int32_t +void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc) { + int32_t nc1 = nc >> 4; + int32_t _nc1 = nc & 15; + int32_t step = sizeof(int32_t) * ldc; + int32_t step1 = sizeof(int32_t) * (NC - (nc1 << 4)); + int32_t volatile m = mc; + + int32_t *volatile c_ptr, *volatile C_ptr; + int32_t *C0, *c0; + c_ptr = c; + C_ptr = C; + if (nc1 > 0) { + asm volatile( + "subs %[mc], %[mc], #1 \n\t" + "blt end_mc_%= \n\t" + "loop_mc_%=: \n\t" + + "mov r6, %[C_ptr] \n\t" + "mov r5, %[nc1] \n\t" + "subs r5, r5, #1 \n\t" + "blt end_nc1_%= \n\t" + 
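+        // Inner loop: each pass copies 16 int32 results (q0-q3, 64 bytes)
+        // from the packed buffer c to the current output row in C. After a
+        // row, C advances by ldc elements (step) while c skips the unused
+        // tail of the packed row, NC - 16 * nc1 elements (step1); the plain
+        // C++ loop below handles the nc % 16 remainder.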
"loop_nc1_%=: \n\t" + + "vld1.32 {q0, q1}, [%[c_ptr]]! \n\t" + "vst1.32 {q0, q1}, [r6]! \n\t" + + "vld1.32 {q2, q3}, [%[c_ptr]]! \n\t" + "vst1.32 {q2, q3}, [r6]! \n\t" + + "subs r5, r5, #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" + + "add %[C_ptr], %[C_ptr], %[step] \n\t" + "add %[c_ptr], %[c_ptr], %[step1] \n\t" + "subs %[mc], %[mc], #1 \n\t" + "bge loop_mc_%= \n\t" + "end_mc_%=: \n\t" + + : + : [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(nc1), + [step] "r"(step), [step1] "r"(step1) + : "memory", "r5", "r6", "q0", "q1", "q2", "q3"); + } + + if (_nc1 != 0) { + for (int32_t i = 0; i < mc; i++) { + C0 = C_ptr + nc1 * 16 + i * ldc; + c0 = c_ptr + nc1 * 16 + i * NC; + for (int32_t j = 0; j < _nc1; j++) { + *C0++ = *c0++; + } + } + } +} + +// C = A * B + C +void Gemm::WriteWithAdd(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc) {} + +// C = A * B + bias +void Gemm::WriteWithAddV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc, int8_t *bias) {} +// C = A * B + C, relu(C) +void Gemm::WriteWithAddRelu(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc) {} + +// C = A * B + bias, relu(C) +void Gemm::WriteWithAddReluV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc, int8_t *bias) {} + +} // namespace math +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/math/im2col.cpp b/src/operators/math/im2col.cpp index 090ccdf24e..502e29a7a9 100644 --- a/src/operators/math/im2col.cpp +++ b/src/operators/math/im2col.cpp @@ -28,91 +28,240 @@ namespace math { * [input_channels, filter_height, filter_width, output_height, * output_width] */ -template -class Im2ColFunctor { - public: - void operator()(const framework::Tensor &im, const std::vector &dilation, - const std::vector &stride, - const std::vector &padding, framework::Tensor *col) { - // PADDLE_ENFORCE(im.dims().size() == 3); - // PADDLE_ENFORCE(col->dims().size() == 5); +template <> +void Im2ColFunctor::operator()( + const framework::Tensor &im, const std::vector &dilation, + const std::vector &stride, const std::vector &padding, + framework::Tensor *col) { + int im_channels = im.dims()[0]; + int im_height = im.dims()[1]; + int im_width = im.dims()[2]; + int filter_height = col->dims()[1]; + int filter_width = col->dims()[2]; + int col_height = col->dims()[3]; + int col_width = col->dims()[4]; + + int channels_col = im_channels * filter_height * filter_width; + const float *im_data = im.data(); + float *col_data = col->data(); +#if __ARM_NEON + const int osize = col_height; + const int isize = im_height; + bool pad1 = padding[0] > 0; + bool pad2 = + (pad1 && padding[1] && + (((isize - 2 * padding[0] + filter_height) % stride[0] == 0) ? 
1 : 0)); + int fill = isize % 2; + if (stride[0] == 1 && filter_height == 3 && pad1 && pad2 && + dilation[0] == 1 && im_height > 2) { + for (int c = 0; c < im_channels; ++c) { + int oosize = osize * osize; + int nk4 = osize / 4; + int mk4 = osize % 4; + + float *col0 = col_data + 0 * oosize + 2 * osize + 2; + float *col1 = col_data + 1 * oosize + 2 * osize + 1; + float *col2 = col_data + 2 * oosize + 2 * osize; + + float *col3 = col_data + 3 * oosize + osize + 2; + float *col4 = col_data + 4 * oosize + osize + 1; + float *col5 = col_data + 5 * oosize + osize; + + float *col6 = col_data + 6 * oosize + 2; + float *col7 = col_data + 7 * oosize + 1; + float *col8 = col_data + 8 * oosize; + + float32x4_t im1; + const float *im_tmp_data = im_data + osize + 1; + + int rrsize = oosize - osize - 1; + int nr4 = rrsize / 4; + int mr4 = rrsize % 4; + for (int i = 0; i < nr4; ++i) { + im1 = vld1q_f32(im_tmp_data); + vst1q_f32(col0, im1); + vst1q_f32(col1, im1); + vst1q_f32(col2, im1); + vst1q_f32(col3, im1); + vst1q_f32(col4, im1); + vst1q_f32(col5, im1); + vst1q_f32(col6, im1); + vst1q_f32(col7, im1); + vst1q_f32(col8, im1); + + col0 += 4; + col1 += 4; + col2 += 4; + col3 += 4; + col4 += 4; + col5 += 4; + col6 += 4; + col7 += 4; + col8 += 4; + + im_tmp_data += 4; + } + for (int i = 0; i < mr4; ++i) { + *col0 = *im_tmp_data; + *col1 = *im_tmp_data; + *col2 = *im_tmp_data; + *col3 = *im_tmp_data; + *col4 = *im_tmp_data; + *col5 = *im_tmp_data; + *col6 = *im_tmp_data; + *col7 = *im_tmp_data; + *col8 = *im_tmp_data; + + col0++; + col1++; + col2++; + col3++; + col4++; + col5++; + col6++; + col7++; + col8++; + + im_tmp_data++; + } - int im_channels = im.dims()[0]; - int im_height = im.dims()[1]; - int im_width = im.dims()[2]; - int filter_height = col->dims()[1]; - int filter_width = col->dims()[2]; - int col_height = col->dims()[3]; - int col_width = col->dims()[4]; - - // PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] - // - - // ((dilation[0] * (filter_height - 1) - // + 1))) / - // stride[0] + - // 1, - // col_height, - // "Output_height and - // padding(padding_up, padding_down) - // are " "inconsistent."); - // PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] - // - - // ((dilation[1] * (filter_width - 1) - // + 1))) / - // stride[1] + - // 1, - // col_width, - // "Output_height and - // padding(padding_up, padding_down) - // are " "inconsistent."); + im_tmp_data = im_data + 1; + col0 = col_data + 0 * oosize + osize + 2; + col1 = col_data + 1 * oosize + osize + 1; + col2 = col_data + 2 * oosize + osize; + + col3 = col_data + 3 * oosize + 2; + col4 = col_data + 4 * oosize + 1; + col5 = col_data + 5 * oosize; + + for (int i = 0; i < nk4; i++) { + im1 = vld1q_f32(im_tmp_data); + vst1q_f32(col0, im1); + vst1q_f32(col1, im1); + vst1q_f32(col2, im1); + vst1q_f32(col3, im1); + vst1q_f32(col4, im1); + vst1q_f32(col5, im1); + + col0 += 4; + col1 += 4; + col2 += 4; + col3 += 4; + col4 += 4; + col5 += 4; + im_tmp_data += 4; + } - int channels_col = im_channels * filter_height * filter_width; - const T *im_data = im.data(); - T *col_data = col->data(); -#if __ARM_NEON - const int osize = col_height; - const int isize = im_height; - bool pad1 = padding[0] > 0; - bool pad2 = - (pad1 && padding[1] && - (((isize - 2 * padding[0] + filter_height) % stride[0] == 0) ? 
1 : 0)); - int fill = isize % 2; - if (stride[0] == 1 && filter_height == 3 && pad1 && pad2 && - dilation[0] == 1 && im_height > 2) { - for (int c = 0; c < im_channels; ++c) { - int oosize = osize * osize; - int nk4 = osize / 4; - int mk4 = osize % 4; - - float *col0 = col_data + 0 * oosize + 2 * osize + 2; - float *col1 = col_data + 1 * oosize + 2 * osize + 1; - float *col2 = col_data + 2 * oosize + 2 * osize; - - float *col3 = col_data + 3 * oosize + osize + 2; - float *col4 = col_data + 4 * oosize + osize + 1; - float *col5 = col_data + 5 * oosize + osize; - - float *col6 = col_data + 6 * oosize + 2; - float *col7 = col_data + 7 * oosize + 1; - float *col8 = col_data + 8 * oosize; - - float32x4_t im1; - const float *im_tmp_data = im_data + osize + 1; - - int rrsize = oosize - osize - 1; - int nr4 = rrsize / 4; - int mr4 = rrsize % 4; - for (int i = 0; i < nr4; ++i) { - im1 = vld1q_f32(im_tmp_data); - vst1q_f32(col0, im1); - vst1q_f32(col1, im1); - vst1q_f32(col2, im1); - vst1q_f32(col3, im1); - vst1q_f32(col4, im1); - vst1q_f32(col5, im1); - vst1q_f32(col6, im1); - vst1q_f32(col7, im1); - vst1q_f32(col8, im1); + for (int i = 0; i < mk4; i++) { + *col0 = *im_tmp_data; + *col1 = *im_tmp_data; + *col2 = *im_tmp_data; + *col3 = *im_tmp_data; + *col4 = *im_tmp_data; + *col5 = *im_tmp_data; + col0++; + col1++; + col2++; + col3++; + col4++; + col5++; + + im_tmp_data++; + } + + // fill 0 1 11; + for (int i = 0; i < osize; ++i) { + col_data[0 * oosize + i * osize] = 0.0; + col_data[3 * oosize + i * osize] = 0.0; + col_data[6 * oosize + i * osize] = 0.0; + + col_data[2 * oosize + osize - 1 + i * osize] = 0.0; + col_data[5 * oosize + osize - 1 + i * osize] = 0.0; + col_data[8 * oosize + osize - 1 + i * osize] = 0.0; + } + + col_data[0 * oosize + osize + 1] = im_data[0]; + col_data[3 * oosize + 1] = im_data[0]; + col_data[6 * oosize + 1] = im_data[osize]; + + col_data[1 * oosize + osize] = im_data[0]; + col_data[4 * oosize] = im_data[0]; + col_data[7 * oosize] = im_data[osize]; + + float32x4_t zero4; + zero4 = vdupq_n_f32(0.0); + auto col_z0 = col_data; + auto col_z1 = col_data + oosize; + auto col_z2 = col_data + 2 * oosize; + auto col_z6 = col_data + 6 * oosize + osize * (osize - 1); + auto col_z7 = col_data + 7 * oosize + osize * (osize - 1); + auto col_z8 = col_data + 8 * oosize + osize * (osize - 1); + + for (int i = 0; i < nk4; ++i) { + vst1q_f32(col_z0, zero4); + vst1q_f32(col_z1, zero4); + vst1q_f32(col_z2, zero4); + vst1q_f32(col_z6, zero4); + vst1q_f32(col_z7, zero4); + vst1q_f32(col_z8, zero4); + + col_z0 += 4; + col_z1 += 4; + col_z2 += 4; + col_z6 += 4; + col_z7 += 4; + col_z8 += 4; + } + + for (int i = 0; i < mk4; ++i) { + col_z0[i] = 0.0; + col_z1[i] = 0.0; + col_z2[i] = 0.0; + col_z6[i] = 0.0; + col_z7[i] = 0.0; + col_z8[i] = 0.0; + } + col_data += 9 * oosize; + im_data += isize * isize; + } + } else if (stride[0] == 2 && filter_height == 3 && pad1 && dilation[0] == 1 && + im_height > 2) { + for (int c = 0; c < im_channels; ++c) { + int oosize = osize * osize; + int nk4 = osize / 4; + int mk4 = osize % 4; + + // 3 2 3 1 0 1 3 2 3 + float *col0 = col_data + 0 * oosize + osize + 1; + float *col1 = col_data + 1 * oosize + osize; + float *col2 = col_data + 2 * oosize + osize; + + float *col3 = col_data + 3 * oosize + 1; + float *col4 = col_data + 4 * oosize; + float *col5 = col_data + 5 * oosize; + + float *col6 = col_data + 6 * oosize + 1; + float *col7 = col_data + 7 * oosize; + float *col8 = col_data + 8 * oosize; + + float32x4x2_t im01; + float32x4x2_t im23; + const float 
*im_tmp_data0 = im_data; + const float *im_tmp_data2 = im_data + isize; + + for (int j = 0; j < osize; ++j) { + for (int i = 0; i < nk4; ++i) { + im01 = vld2q_f32(im_tmp_data0); + im23 = vld2q_f32(im_tmp_data2); + vst1q_f32(col0, im23.val[1]); + vst1q_f32(col1, im23.val[0]); + vst1q_f32(col2, im23.val[1]); + vst1q_f32(col3, im01.val[1]); + vst1q_f32(col4, im01.val[0]); + vst1q_f32(col5, im01.val[1]); + vst1q_f32(col6, im23.val[1]); + vst1q_f32(col7, im23.val[0]); + vst1q_f32(col8, im23.val[1]); col0 += 4; col1 += 4; @@ -124,18 +273,21 @@ class Im2ColFunctor { col7 += 4; col8 += 4; - im_tmp_data += 4; + im_tmp_data0 += 8; + im_tmp_data2 += 8; } - for (int i = 0; i < mr4; ++i) { - *col0 = *im_tmp_data; - *col1 = *im_tmp_data; - *col2 = *im_tmp_data; - *col3 = *im_tmp_data; - *col4 = *im_tmp_data; - *col5 = *im_tmp_data; - *col6 = *im_tmp_data; - *col7 = *im_tmp_data; - *col8 = *im_tmp_data; + const float *im_tmp_data1 = im_tmp_data0 + 1; + const float *im_tmp_data3 = im_tmp_data2 + 1; + for (int i = 0; i < mk4; ++i) { + *col0 = *im_tmp_data3; + *col1 = *im_tmp_data2; + *col2 = *im_tmp_data3; + *col3 = *im_tmp_data1; + *col4 = *im_tmp_data0; + *col5 = *im_tmp_data1; + *col6 = *im_tmp_data3; + *col7 = *im_tmp_data2; + *col8 = *im_tmp_data3; col0++; col1++; @@ -146,271 +298,72 @@ class Im2ColFunctor { col6++; col7++; col8++; - - im_tmp_data++; - } - - im_tmp_data = im_data + 1; - col0 = col_data + 0 * oosize + osize + 2; - col1 = col_data + 1 * oosize + osize + 1; - col2 = col_data + 2 * oosize + osize; - - col3 = col_data + 3 * oosize + 2; - col4 = col_data + 4 * oosize + 1; - col5 = col_data + 5 * oosize; - - for (int i = 0; i < nk4; i++) { - im1 = vld1q_f32(im_tmp_data); - vst1q_f32(col0, im1); - vst1q_f32(col1, im1); - vst1q_f32(col2, im1); - vst1q_f32(col3, im1); - vst1q_f32(col4, im1); - vst1q_f32(col5, im1); - - col0 += 4; - col1 += 4; - col2 += 4; - col3 += 4; - col4 += 4; - col5 += 4; - im_tmp_data += 4; + im_tmp_data0 += 2; + im_tmp_data1 += 2; + im_tmp_data2 += 2; + im_tmp_data3 += 2; } - - for (int i = 0; i < mk4; i++) { - *col0 = *im_tmp_data; - *col1 = *im_tmp_data; - *col2 = *im_tmp_data; - *col3 = *im_tmp_data; - *col4 = *im_tmp_data; - *col5 = *im_tmp_data; - col0++; - col1++; - col2++; - col3++; - col4++; - col5++; - - im_tmp_data++; - } - - // fill 0 1 11; - for (int i = 0; i < osize; ++i) { - col_data[0 * oosize + i * osize] = 0.0; - col_data[3 * oosize + i * osize] = 0.0; - col_data[6 * oosize + i * osize] = 0.0; - + im_tmp_data0 += (isize - fill); + im_tmp_data2 += (isize - fill); + } + for (int i = 0; i < osize; ++i) { + col_data[0 * oosize + i * osize] = 0.0; + col_data[3 * oosize + i * osize] = 0.0; + col_data[6 * oosize + i * osize] = 0.0; + if (pad2) { col_data[2 * oosize + osize - 1 + i * osize] = 0.0; col_data[5 * oosize + osize - 1 + i * osize] = 0.0; col_data[8 * oosize + osize - 1 + i * osize] = 0.0; } - - col_data[0 * oosize + osize + 1] = im_data[0]; - col_data[3 * oosize + 1] = im_data[0]; - col_data[6 * oosize + 1] = im_data[osize]; - - col_data[1 * oosize + osize] = im_data[0]; - col_data[4 * oosize] = im_data[0]; - col_data[7 * oosize] = im_data[osize]; - - float32x4_t zero4; - zero4 = vdupq_n_f32(0.0); - auto col_z0 = col_data; - auto col_z1 = col_data + oosize; - auto col_z2 = col_data + 2 * oosize; - auto col_z6 = col_data + 6 * oosize + osize * (osize - 1); - auto col_z7 = col_data + 7 * oosize + osize * (osize - 1); - auto col_z8 = col_data + 8 * oosize + osize * (osize - 1); - - for (int i = 0; i < nk4; ++i) { - vst1q_f32(col_z0, zero4); - 
vst1q_f32(col_z1, zero4); - vst1q_f32(col_z2, zero4); + } + float32x4_t zero4; + zero4 = vdupq_n_f32(0.0); + auto col_z0 = col_data; + auto col_z1 = col_data + oosize; + auto col_z2 = col_data + 2 * oosize; + auto col_z6 = col_data + 6 * oosize + osize * (osize - 1); + auto col_z7 = col_data + 7 * oosize + osize * (osize - 1); + auto col_z8 = col_data + 8 * oosize + osize * (osize - 1); + + for (int i = 0; i < nk4; ++i) { + vst1q_f32(col_z0, zero4); + vst1q_f32(col_z1, zero4); + vst1q_f32(col_z2, zero4); + if (pad2) { vst1q_f32(col_z6, zero4); vst1q_f32(col_z7, zero4); vst1q_f32(col_z8, zero4); - - col_z0 += 4; - col_z1 += 4; - col_z2 += 4; - col_z6 += 4; - col_z7 += 4; - col_z8 += 4; } + col_z0 += 4; + col_z1 += 4; + col_z2 += 4; + col_z6 += 4; + col_z7 += 4; + col_z8 += 4; + } - for (int i = 0; i < mk4; ++i) { - col_z0[i] = 0.0; - col_z1[i] = 0.0; - col_z2[i] = 0.0; + for (int i = 0; i < mk4; ++i) { + col_z0[i] = 0.0; + col_z1[i] = 0.0; + col_z2[i] = 0.0; + if (pad2) { col_z6[i] = 0.0; col_z7[i] = 0.0; col_z8[i] = 0.0; } - col_data += 9 * oosize; - im_data += isize * isize; } - } else if (stride[0] == 2 && filter_height == 3 && pad1 && - dilation[0] == 1 && im_height > 2) { - for (int c = 0; c < im_channels; ++c) { - int oosize = osize * osize; - int nk4 = osize / 4; - int mk4 = osize % 4; - - // 3 2 3 1 0 1 3 2 3 - float *col0 = col_data + 0 * oosize + osize + 1; - float *col1 = col_data + 1 * oosize + osize; - float *col2 = col_data + 2 * oosize + osize; - - float *col3 = col_data + 3 * oosize + 1; - float *col4 = col_data + 4 * oosize; - float *col5 = col_data + 5 * oosize; - - float *col6 = col_data + 6 * oosize + 1; - float *col7 = col_data + 7 * oosize; - float *col8 = col_data + 8 * oosize; - - float32x4x2_t im01; - float32x4x2_t im23; - const float *im_tmp_data0 = im_data; - const float *im_tmp_data2 = im_data + isize; - - for (int j = 0; j < osize; ++j) { - for (int i = 0; i < nk4; ++i) { - im01 = vld2q_f32(im_tmp_data0); - im23 = vld2q_f32(im_tmp_data2); - vst1q_f32(col0, im23.val[1]); - vst1q_f32(col1, im23.val[0]); - vst1q_f32(col2, im23.val[1]); - vst1q_f32(col3, im01.val[1]); - vst1q_f32(col4, im01.val[0]); - vst1q_f32(col5, im01.val[1]); - vst1q_f32(col6, im23.val[1]); - vst1q_f32(col7, im23.val[0]); - vst1q_f32(col8, im23.val[1]); - - col0 += 4; - col1 += 4; - col2 += 4; - col3 += 4; - col4 += 4; - col5 += 4; - col6 += 4; - col7 += 4; - col8 += 4; - - im_tmp_data0 += 8; - im_tmp_data2 += 8; - } - const float *im_tmp_data1 = im_tmp_data0 + 1; - const float *im_tmp_data3 = im_tmp_data2 + 1; - for (int i = 0; i < mk4; ++i) { - *col0 = *im_tmp_data3; - *col1 = *im_tmp_data2; - *col2 = *im_tmp_data3; - *col3 = *im_tmp_data1; - *col4 = *im_tmp_data0; - *col5 = *im_tmp_data1; - *col6 = *im_tmp_data3; - *col7 = *im_tmp_data2; - *col8 = *im_tmp_data3; - - col0++; - col1++; - col2++; - col3++; - col4++; - col5++; - col6++; - col7++; - col8++; - im_tmp_data0 += 2; - im_tmp_data1 += 2; - im_tmp_data2 += 2; - im_tmp_data3 += 2; - } - im_tmp_data0 += (isize - fill); - im_tmp_data2 += (isize - fill); - } - for (int i = 0; i < osize; ++i) { - col_data[0 * oosize + i * osize] = 0.0; - col_data[3 * oosize + i * osize] = 0.0; - col_data[6 * oosize + i * osize] = 0.0; - if (pad2) { - col_data[2 * oosize + osize - 1 + i * osize] = 0.0; - col_data[5 * oosize + osize - 1 + i * osize] = 0.0; - col_data[8 * oosize + osize - 1 + i * osize] = 0.0; - } - } - float32x4_t zero4; - zero4 = vdupq_n_f32(0.0); - auto col_z0 = col_data; - auto col_z1 = col_data + oosize; - auto col_z2 = col_data + 2 * 
oosize; - auto col_z6 = col_data + 6 * oosize + osize * (osize - 1); - auto col_z7 = col_data + 7 * oosize + osize * (osize - 1); - auto col_z8 = col_data + 8 * oosize + osize * (osize - 1); - - for (int i = 0; i < nk4; ++i) { - vst1q_f32(col_z0, zero4); - vst1q_f32(col_z1, zero4); - vst1q_f32(col_z2, zero4); - if (pad2) { - vst1q_f32(col_z6, zero4); - vst1q_f32(col_z7, zero4); - vst1q_f32(col_z8, zero4); - } - col_z0 += 4; - col_z1 += 4; - col_z2 += 4; - col_z6 += 4; - col_z7 += 4; - col_z8 += 4; - } - - for (int i = 0; i < mk4; ++i) { - col_z0[i] = 0.0; - col_z1[i] = 0.0; - col_z2[i] = 0.0; - if (pad2) { - col_z6[i] = 0.0; - col_z7[i] = 0.0; - col_z8[i] = 0.0; - } - } - col_data[1 * oosize + osize] = im_data[isize]; - for (int i = 1; i < osize; ++i) { - col_data[3 * oosize + i] = im_data[(i - 1) * stride[0] + 1]; - } - col_data[4 * oosize] = im_data[0]; - col_data[7 * oosize] = im_data[isize]; - - col_data += 9 * oosize; - im_data += isize * isize; - } - } else { - for (int c = 0; c < channels_col; ++c) { - int w_offset = c % filter_width; - int h_offset = (c / filter_width) % filter_height; - int c_im = c / (filter_width * filter_height); - for (int h = 0; h < col_height; ++h) { - int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0]; - for (int w = 0; w < col_width; ++w) { - int im_col_idx = - w * stride[1] - padding[1] + w_offset * dilation[1]; - int col_idx = (c * col_height + h) * col_width + w; - int im_idx = - (im_row_idx + c_im * im_height) * im_width + im_col_idx; - - col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height || - im_col_idx < 0 || im_col_idx >= im_width) - ? static_cast(0) - : im_data[im_idx]; - } - } + col_data[1 * oosize + osize] = im_data[isize]; + for (int i = 1; i < osize; ++i) { + col_data[3 * oosize + i] = im_data[(i - 1) * stride[0] + 1]; } + col_data[4 * oosize] = im_data[0]; + col_data[7 * oosize] = im_data[isize]; + + col_data += 9 * oosize; + im_data += isize * isize; } -#else + } else { for (int c = 0; c < channels_col; ++c) { int w_offset = c % filter_width; int h_offset = (c / filter_width) % filter_height; @@ -424,14 +377,122 @@ class Im2ColFunctor { col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height || im_col_idx < 0 || im_col_idx >= im_width) - ? static_cast(0) + ? static_cast(0) : im_data[im_idx]; } } } + } +#else + for (int c = 0; c < channels_col; ++c) { + int w_offset = c % filter_width; + int h_offset = (c / filter_width) % filter_height; + int c_im = c / (filter_width * filter_height); + for (int h = 0; h < col_height; ++h) { + int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0]; + for (int w = 0; w < col_width; ++w) { + int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1]; + int col_idx = (c * col_height + h) * col_width + w; + int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx; + + col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height || + im_col_idx < 0 || im_col_idx >= im_width) + ? 
static_cast(0) + : im_data[im_idx]; + } + } + } #endif +} + +// TODO(hjchen2) +void ExtractToRows1() {} + +void ExtractToRows2() {} + +/* + * im = [input_channels, input_height, input_width] + * col = + * [input_channels, filter_height, filter_width, output_height, + * output_width] + */ +template <> +void Im2ColFunctor::operator()( + const framework::Tensor &im, const std::vector &dilation, + const std::vector &stride, const std::vector &padding, + framework::Tensor *col) { + int im_channels = im.dims()[0]; + int im_height = im.dims()[1]; + int im_width = im.dims()[2]; + int filter_height = col->dims()[1]; + int filter_width = col->dims()[2]; + int col_height = col->dims()[3]; + int col_width = col->dims()[4]; + + int channels_col = im_channels * filter_height * filter_width; + const int8_t *im_data = im.data(); + int8_t *col_data = col->data(); +// #if defined(__ARM_NEON__) || defined(__ARM_NEON) +#if 0 + if (stride[0] == stride[1] && stride[0] == 1 && dilation[0] == 1 && + padding[0] == padding[1] && dilation[0] == dilation[1]) { + // pad 0 + memset(col_data, 0, col->numel() * sizeof(int8_t)); + for (int ic = 0; ic < im_channels; ++ic) { + for (int oh = 0; oh < padding[0]; ++oh) { + for (int k = 0; k < filter_height * filter_width; ++k) { + ExtractToRows1(); + ExtractToRows1(); + } + } + for (int oh = padding[0]; oh < col_height - padding[0]; ++oh) { + for (int k = 0; k < filter_height * filter_width; ++k) { + ExtractToRows1(); + } + } + } + } else if (stride[0] == stride[1] && stride[0] == 2 && dilation[0] == 1 && + padding[0] == padding[1] && dilation[0] == dilation[1]) { + // pad 0 + memset(col_data, 0, col->numel() * sizeof(int8_t)); + for (int ic = 0; ic < im_channels; ++ic) { + for (int oh = 0; oh < padding[0]; ++oh) { + for (int k = 0; k < filter_height * filter_width; ++k) { + ExtractToRows2(); + ExtractToRows2(); + } + } + for (int oh = padding[0]; oh < col_height - padding[0]; ++oh) { + for (int k = 0; k < filter_height * filter_width; ++k) { + ExtractToRows2(); + } + } + } + } else { +#endif + for (int c = 0; c < channels_col; ++c) { + int w_offset = c % filter_width; + int h_offset = (c / filter_width) % filter_height; + int c_im = c / (filter_width * filter_height); + for (int h = 0; h < col_height; ++h) { + int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0]; + for (int w = 0; w < col_width; ++w) { + int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1]; + int col_idx = (c * col_height + h) * col_width + w; + int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx; + + col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height || + im_col_idx < 0 || im_col_idx >= im_width) + ? 
static_cast(0) + : im_data[im_idx]; + } + } } -}; +// #if defined(__ARM_NEON__) || defined(__ARM_NEON) +#if 0 + } +#endif +} /* * im = [input_channels, input_height, input_width] @@ -456,27 +517,6 @@ class Col2ImFunctor { int col_height = col.dims()[3]; int col_width = col.dims()[4]; - // PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] - // - - // ((dilation[0] * (filter_height - 1) - // + 1))) / - // stride[0] + - // 1, - // col_height, - // "Output_height and - // padding(padding_up, padding_down) - // are " "inconsistent."); - // PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] - // - - // ((dilation[1] * (filter_width - 1) - // + 1))) / - // stride[1] + - // 1, - // col_width, - // "Output_height and - // padding(padding_up, padding_down) - // are " "inconsistent."); - int channels_col = im_channels * filter_height * filter_width; T *im_data = im->data(); @@ -503,9 +543,9 @@ class Col2ImFunctor { }; template class Im2ColFunctor; -// template class Im2ColFunctor; +template class Im2ColFunctor; template class Col2ImFunctor; -template class Col2ImFunctor; +template class Col2ImFunctor; /* * im = [input_channels, input_height, input_width] @@ -519,8 +559,6 @@ class Im2ColFunctor { void operator()(const framework::Tensor &im, const std::vector &dilation, const std::vector &stride, const std::vector &padding, framework::Tensor *col) { - // PADDLE_ENFORCE(im.dims().size() == 3); - // PADDLE_ENFORCE(col->dims().size() == 5); int im_channels = im.dims()[0]; int im_height = im.dims()[1]; int im_width = im.dims()[2]; @@ -529,19 +567,6 @@ class Im2ColFunctor { int col_height = col->dims()[0]; int col_width = col->dims()[1]; - // PADDLE_ENFORCE_EQ( - // (im_height + padding[0] + padding[2] - - // filter_height) / stride[0] - // + 1, col_height, "Output_height and - // padding(padding_up, - // padding_down) are " "inconsistent."); - // PADDLE_ENFORCE_EQ( - // (im_width + padding[1] + padding[3] - - // filter_width) / stride[1] + - // 1, col_width, "col_width and padding(padding_left, - // padding_right) - // are " "inconsistent."); - const T *im_data = im.data(); T *col_data = col->data(); @@ -593,8 +618,6 @@ class Col2ImFunctor { const std::vector &dilation, const std::vector &stride, const std::vector &padding, framework::Tensor *im) { - // PADDLE_ENFORCE(im->dims().size() == 3); - // PADDLE_ENFORCE(col.dims().size() == 5); int im_channels = im->dims()[0]; int im_height = im->dims()[1]; int im_width = im->dims()[2]; @@ -603,19 +626,6 @@ class Col2ImFunctor { int col_height = col.dims()[0]; int col_width = col.dims()[1]; - // PADDLE_ENFORCE_EQ( - // (im_height + padding[0] + padding[2] - - // filter_height) / stride[0] - // + 1, col_height, "Output_height and - // padding(padding_up, - // padding_down) are " "inconsistent."); - // PADDLE_ENFORCE_EQ( - // (im_width + padding[1] + padding[3] - - // filter_width) / stride[1] + - // 1, col_width, "col_width and padding(padding_left, - // padding_right) - // are " "inconsistent."); - T *im_data = im->data(); const T *col_data = col.data(); @@ -655,9 +665,7 @@ class Col2ImFunctor { }; template class Im2ColFunctor; -template class Im2ColFunctor; template class Col2ImFunctor; -template class Col2ImFunctor; } // namespace math } // namespace operators diff --git a/src/operators/math/math_function.cpp b/src/operators/math/math_function.cpp index 9d39f89b04..fc4c385add 100644 --- a/src/operators/math/math_function.cpp +++ b/src/operators/math/math_function.cpp @@ -135,7 +135,7 @@ template struct ClearTensor { void operator()(framework::Tensor 
*tensor) { auto size = tensor->numel(); - auto *tensor_data = tensor->data(); + auto *tensor_data = tensor->data(); memset((void *)tensor_data, 0, sizeof(T) * size); // NOLINT } }; @@ -151,9 +151,9 @@ struct RowwiseAdd { PADDLE_MOBILE_ENFORCE((output->dims() == in_dims), "output->dims() must be equal to in_dims."); - auto *input_data = input.data(); - auto *out_data = output->data(); - auto *vec_data = vector.data(); + auto *input_data = input.data(); + auto *out_data = output->data(); + auto *vec_data = vector.data(); for (int64_t i = 0; i < in_dims[0]; ++i) { for (int64_t j = 0; j < size; ++j) { out_data[i * size + j] = input_data[i * size + j] + vec_data[j]; diff --git a/src/operators/math/math_function.h b/src/operators/math/math_function.h index de19e3df2a..da8f3e042a 100644 --- a/src/operators/math/math_function.h +++ b/src/operators/math/math_function.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include +#include #include "framework/tensor.h" namespace paddle_mobile { @@ -25,7 +26,7 @@ template void matmul(const framework::Tensor &matrix_a, bool trans_a, const framework::Tensor &matrix_b, bool trans_b, T alpha, framework::Tensor *matrix_out, T beta, bool relu = false, - float *bias = nullptr); + T *bias = nullptr); template void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a, diff --git a/src/operators/math/math_function_int8.cpp b/src/operators/math/math_function_int8.cpp new file mode 100644 index 0000000000..70677223d1 --- /dev/null +++ b/src/operators/math/math_function_int8.cpp @@ -0,0 +1,64 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "operators/math/gemm.h" +#include "operators/math/math_function.h" + +namespace paddle_mobile { +namespace operators { +namespace math { +template <> +void matmul(const framework::Tensor &matrix_a, bool trans_a, + const framework::Tensor &matrix_b, bool trans_b, + int8_t alpha, framework::Tensor *matrix_out, int8_t beta, + bool relu, int8_t *bias) { + auto dim_a = matrix_a.dims(); + auto dim_b = matrix_b.dims(); + auto dim_out = matrix_out->dims(); + PADDLE_MOBILE_ENFORCE( + dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, + "The input and output of matmul be matrix"); + + int32_t M = dim_out[0]; + int32_t N = dim_out[1]; + int32_t K = (!trans_a) ? 
dim_a[1] : dim_a[0]; + Gemm gemm; + + if (trans_a) { + int32_t numel = matrix_a.numel(); + int32_t m = matrix_a.dims()[0]; + int32_t n = matrix_a.dims()[1]; + int8_t *tmp = (int8_t *)(matrix_a.data()); // NOLINT + int8_t *a = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * numel)); + int32_t index = 0; + for (int32_t j = 0; j < n; j++) { + for (int32_t i = 0; i < m; i++) { + a[index++] = tmp[i * n + j]; + } + } + + gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data(), N, beta, + matrix_out->data(), N, relu, bias); + } else { + gemm.Sgemm(M, N, K, alpha, matrix_a.data(), K, + matrix_b.data(), N, beta, matrix_out->data(), N, + relu, bias); + } +} +} // namespace math +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/math/vol2col.cpp b/src/operators/math/vol2col.cpp index afee3f7f85..9311e9e229 100644 --- a/src/operators/math/vol2col.cpp +++ b/src/operators/math/vol2col.cpp @@ -32,9 +32,6 @@ class Vol2ColFunctor { void operator()(const Tensor &vol, const std::vector &dilations, const std::vector &strides, const std::vector &paddings, Tensor *col) const { - // PADDLE_ENFORCE(vol.dims().size() == 4); - // PADDLE_ENFORCE(col->dims().size() == 7); - int input_channels = vol.dims()[0]; int input_depth = vol.dims()[1]; int input_height = vol.dims()[2]; @@ -48,32 +45,6 @@ class Vol2ColFunctor { int channels_col = input_channels * filter_depth * filter_height * filter_width; - // PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] - - // ((dilations[0] * (filter_depth - 1) - // + 1))) / - // strides[0] + - // 1, - // output_depth, - // "input_depth and output_depth are " - // "mismatching."); - // PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] - - // ((dilations[1] * (filter_height - - // 1) + 1))) / - // strides[1] + - // 1, - // output_height, - // "input_height and output_height are - // " - // "mismatching."); - // PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] - - // ((dilations[2] * (filter_width - 1) - // + 1))) / - // strides[2] + - // 1, - // output_width, - // "input_width and output_width are " - // "mismatching."); - const T *vol_data = vol.data(); T *col_data = col->data(); @@ -119,9 +90,6 @@ class Col2VolFunctor { void operator()(const Tensor &col, const std::vector &dilations, const std::vector &strides, const std::vector &paddings, Tensor *vol) const { - // PADDLE_ENFORCE(vol->dims().size() == 4); - // PADDLE_ENFORCE(col.dims().size() == 7); - int input_channels = vol->dims()[0]; int input_depth = vol->dims()[1]; int input_height = vol->dims()[2]; @@ -135,31 +103,6 @@ class Col2VolFunctor { int channels_col = input_channels * filter_depth * filter_height * filter_width; - // PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] - - // ((dilations[0] * (filter_depth - 1) - // + 1))) / - // strides[0] + - // 1, - // output_depth, - // "input_depth and output_depth are " - // "mismatching."); - // PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] - - // ((dilations[1] * (filter_height - - // 1) + 1))) / - // strides[1] + - // 1, - // output_height, - // "input_height and output_height are - // " - // "mismatching."); - // PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] - - // ((dilations[2] * (filter_width - 1) - // + 1))) / - // strides[2] + - // 1, - // output_width, - // "input_width and output_width are " - // "mismatching."); T *vol_data = vol->data(); const T *col_data = col.data(); @@ -195,9 +138,9 @@ class Col2VolFunctor { }; template class Vol2ColFunctor; -template class Vol2ColFunctor; +template class Vol2ColFunctor; template class 
Col2VolFunctor; -template class Col2VolFunctor; +template class Col2VolFunctor; } // namespace math } // namespace operators diff --git a/src/operators/op_param.h b/src/operators/op_param.h index 1c707f960d..9c89a5b9b9 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -2150,14 +2150,12 @@ class QuantizeParam : public OpParam { const AttributeMap &attrs, const Scope &scope) { input_ = InputXFrom(inputs, scope); out_ = OutFrom(outputs, scope); - if (HasAttr("is_static", attrs)) { - is_static_ = GetAttr("is_static", attrs); - } // online // scale = max(abs(x)) online_scale_ = GetVarValue("OutScale", outputs, scope); // offline if (HasAttr("static_scale", attrs)) { + is_static_ = true; static_scale_ = GetAttr("static_scale", attrs); } // x = round(scale * x) @@ -2179,7 +2177,7 @@ class QuantizeParam : public OpParam { float static_scale_ = 1.0f; // round method type // nearest_zero and nearest_even is valid currently - RoundType round_type_ = ROUND_NEAREST_TO_EVEN; + RoundType round_type_ = ROUND_NEAREST_AWAY_ZERO; }; template diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 1209b4e3f5..b38ff2e47a 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -258,6 +258,10 @@ if (NOT FOUND_MATCH) ADD_EXECUTABLE(test-gemm-accuracy common/test_gemm_accuracy.cpp) target_link_libraries(test-gemm-accuracy paddle-mobile) + # gen test + ADD_EXECUTABLE(test-gemm-int8-accuracy common/test_gemm_int8_accuracy.cpp) + target_link_libraries(test-gemm-int8-accuracy paddle-mobile) + # gen test ADD_EXECUTABLE(test-gemm-perf common/test_gemm_perf.cpp) target_link_libraries(test-gemm-perf paddle-mobile) diff --git a/test/common/test_gemm_accuracy.cpp b/test/common/test_gemm_accuracy.cpp index 0967094f68..2a2505a86b 100644 --- a/test/common/test_gemm_accuracy.cpp +++ b/test/common/test_gemm_accuracy.cpp @@ -84,7 +84,7 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) { } paddle_mobile::operators::math::Gemm gemm; - gemm.SgemmWithBn(m, n, k, 0.9, a, lda, b, ldb, 0.3, c, ldc, relu, scale, bias, + gemm.SgemmWithBn(m, n, k, 1, a, lda, b, ldb, 0.3, c, ldc, relu, scale, bias, nullptr); int eq = 0; int neq = 0; diff --git a/test/common/test_gemm_int8_accuracy.cpp b/test/common/test_gemm_int8_accuracy.cpp new file mode 100644 index 0000000000..80ddd40e12 --- /dev/null +++ b/test/common/test_gemm_int8_accuracy.cpp @@ -0,0 +1,131 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include +#include +#include +#include "../test_helper.h" +#include "common/log.h" +#include "memory/t_malloc.h" +#include "operators/math/gemm.h" + +#define a(i, j) a[(i)*lda + (j)] +#define b(i, j) b[(i)*ldb + (j)] +#define c(i, j) c[(i)*ldc + (j)] +#define c1(i, j) c1[(i)*ldc + (j)] + +using std::default_random_engine; +using std::uniform_int_distribution; + +void print_matirx(int m, int n, int ldc, int32_t *c) { + for (int i = 0; i < m; ++i) { + std::cout << c(i, 0); + for (int j = 1; j < n; ++j) { + std::cout << " | " << c(i, j); + } + std::cout << std::endl; + } + std::cout << std::endl; +} + +void print_matirx(int m, int n, int ldc, int8_t *c) { + for (int i = 0; i < m; ++i) { + std::cout << static_cast(c(i, 0)); + for (int j = 1; j < n; ++j) { + std::cout << " | " << static_cast(c(i, j)); + } + std::cout << std::endl; + } + std::cout << std::endl; +} + +int do_sgemm(int m, int n, int k, bool relu, int pr) { + int lda = k; + int ldb = n; + int ldc = n; + default_random_engine e; + uniform_int_distribution pixel(-127, 127); + int8_t *a = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * m * k)); + int8_t *b = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * k * n)); + int32_t *c = static_cast( + paddle_mobile::memory::Alloc(sizeof(int32_t) * m * n)); + int32_t *c1 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int32_t) * m * n)); + + for (int i = 0; i < m * k; ++i) { + a[i] = pixel(e); + } + for (int i = 0; i < k * n; ++i) { + b[i] = pixel(e); + } + + for (int i = 0; i < m; ++i) { + for (int j = 0; j < n; ++j) { + int32_t r = 0; + for (int p = 0; p < k; p++) { + r += static_cast(a(i, p)) * static_cast(b(p, j)); + } + c1(i, j) = r; + } + } + + paddle_mobile::operators::math::Gemm gemm; + gemm.Sgemm(m, n, k, static_cast(1), a, lda, b, ldb, + static_cast(0), c, ldc, relu, nullptr); + int eq = 0; + int neq = 0; + for (int i = 0; i < m * n; ++i) { + if (c[i] == c1[i]) { + ++eq; + } else { + ++neq; + } + } + + if (pr > 0) { + std::cout << "A:" << std::endl; + print_matirx(m, k, lda, a); + std::cout << "B:" << std::endl; + print_matirx(k, n, ldb, b); + std::cout << "C:" << std::endl; + print_matirx(m, n, ldc, c); + std::cout << "C1:" << std::endl; + print_matirx(m, n, ldc, c1); + } + + std::cout << "mnk=" << m << " " << n << " " << k << " relu=" << relu + << " eq=" << eq << " neq=" << neq << std::endl; + + paddle_mobile::memory::Free(a); + paddle_mobile::memory::Free(b); + paddle_mobile::memory::Free(c); + paddle_mobile::memory::Free(c1); + + return 0; +} + +int main() { + do_sgemm(9, 9, 9, false, 10); + do_sgemm(10, 6, 12, false, 0); + do_sgemm(512, 256, 384, false, 0); + do_sgemm(1366, 768, 256, false, 0); + do_sgemm(1255, 755, 333, false, 0); + do_sgemm(555, 777, 999, false, 0); + do_sgemm(1024, 1024, 1024, false, 0); + + return 0; +} diff --git a/test/common/test_gemm_perf.cpp b/test/common/test_gemm_perf.cpp index 386c09d71a..89f0012ae8 100644 --- a/test/common/test_gemm_perf.cpp +++ b/test/common/test_gemm_perf.cpp @@ -28,13 +28,11 @@ limitations under the License. 
*/ int main() { paddle_mobile::PaddleMobile paddle_mobile; - paddle_mobile.SetThreadNum(4); - Tensor aa, bb, cc, scale, bias; + paddle_mobile.SetThreadNum(1); + Tensor aa, bb, cc; auto aaptr = aa.mutable_data({m, k}); auto bbptr = bb.mutable_data({k, n}); auto ccptr = cc.mutable_data({m, n}); - auto scaleptr = scale.mutable_data({m}); - auto biasptr = bias.mutable_data({m}); for (int i = 0; i < m * k; ++i) { aaptr[i] = 2; @@ -45,23 +43,55 @@ int main() { for (int i = 0; i < m * n; ++i) { ccptr[i] = 2; } - for (int i = 0; i < m; ++i) { - scaleptr[i] = 1; - biasptr[i] = 0; + + Tensor aa_int8, bb_int8, cc_int8; + auto aaptr_int8 = aa_int8.mutable_data({m, k}); + auto bbptr_int8 = bb_int8.mutable_data({k, n}); + auto ccptr_int8 = cc_int8.mutable_data({m, n}); + + for (int i = 0; i < m * k; ++i) { + aaptr_int8[i] = static_cast(2); + } + for (int i = 0; i < k * n; ++i) { + bbptr_int8[i] = static_cast(2); + } + for (int i = 0; i < m * n; ++i) { + ccptr_int8[i] = static_cast(2); } - auto time1 = time(); + // float + // warm-up 10 times for (int j = 0; j < 10; ++j) { paddle_mobile::operators::math::matmul( aa, false, bb, false, static_cast(1), &cc, static_cast(0), - false, biasptr); + false, nullptr); + } - // paddle_mobile::operators::math::matmulWithBn( - // aa, false, bb, false, static_cast(1), &cc, - // static_cast(0), true, &scale, &bias, 0); + auto time1 = time(); + for (int j = 0; j < 10; ++j) { + paddle_mobile::operators::math::matmul( + aa, false, bb, false, static_cast(1), &cc, static_cast(0), + false, nullptr); } auto time2 = time(); - std::cout << "gemm cost :" << time_diff(time1, time2) / 10 << "ms\n"; + std::cout << "float gemm cost :" << time_diff(time1, time2) / 10 << "ms\n"; + + // int8_t + // warm-up 10 times + for (int j = 0; j < 10; ++j) { + paddle_mobile::operators::math::matmul( + aa_int8, false, bb_int8, false, static_cast(1), &cc_int8, + static_cast(0), false, nullptr); + } + + auto time3 = time(); + for (int j = 0; j < 10; ++j) { + paddle_mobile::operators::math::matmul( + aa_int8, false, bb_int8, false, static_cast(1), &cc_int8, + static_cast(0), false, nullptr); + } + auto time4 = time(); + std::cout << "int8_t gemm cost :" << time_diff(time3, time4) / 10 << "ms\n"; return 0; } diff --git a/test/net/test_googlenet.cpp b/test/net/test_googlenet.cpp index a2f030eeac..c88a78974c 100644 --- a/test/net/test_googlenet.cpp +++ b/test/net/test_googlenet.cpp @@ -25,27 +25,31 @@ int main() { paddle_mobile::PaddleMobile paddle_mobile; #endif - paddle_mobile.SetThreadNum(4); - bool optimize = true; + paddle_mobile.SetThreadNum(1); + bool optimize = false; auto time1 = time(); if (paddle_mobile.Load(g_googlenet, optimize)) { auto time2 = time(); std::cout << "load cost :" << time_diff(time1, time2) << "ms" << std::endl; std::vector input; + std::vector output; std::vector dims{1, 3, 224, 224}; GetInput(g_test_image_1x3x224x224, &input, dims); - // 预热十次 - for (int i = 0; i < 10; ++i) { - auto vec_result = paddle_mobile.Predict(input, dims); - } + // // 预热十次 + // for (int i = 0; i < 10; ++i) { + // output = paddle_mobile.Predict(input, dims); + // } auto time3 = time(); for (int i = 0; i < 10; ++i) { - auto vec_result = paddle_mobile.Predict(input, dims); + output = paddle_mobile.Predict(input, dims); } auto time4 = time(); std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms" << std::endl; + for (int i = 0; i < output.size(); ++i) { + DLOG << "result[" << i << "] = " << output[i]; + } } return 0; } diff --git a/test/operators/test_dequantize_op.cpp 
b/test/operators/test_dequantize_op.cpp index 8c61ae32d9..8e89d8f7af 100644 --- a/test/operators/test_dequantize_op.cpp +++ b/test/operators/test_dequantize_op.cpp @@ -59,7 +59,7 @@ int TestDequqntizeOp() { framework::Tensor output_cmp; output_cmp.Resize(dim); - float dequant_scale = 1.f / (1.27 * 1.74); + float dequant_scale = 1.27 / 1.74; dequantize(input, dequant_scale, &output_cmp); const float* output_cmp_data = output_cmp.data(); for (int i = 0; i < output->numel(); ++i) { diff --git a/test/operators/test_int8_conv_op.cpp b/test/operators/test_int8_conv_op.cpp index 4ebc24a9e6..fd9e45e9a2 100644 --- a/test/operators/test_int8_conv_op.cpp +++ b/test/operators/test_int8_conv_op.cpp @@ -140,10 +140,10 @@ int TestConvOp() { int dilation_w = 1; int batch_size = 1; - int input_c = 3; - int input_h = 25; - int input_w = 25; - int output_c = 3; + int input_c = 63; + int input_h = 51; + int input_w = 51; + int output_c = 125; framework::DDim input_shape = framework::make_ddim({batch_size, input_c, input_h, input_w}); framework::DDim filter_shape = @@ -158,11 +158,11 @@ int TestConvOp() { auto input_var = scope.get()->Var("input"); auto input = input_var->template GetMutable(); - SetupTensor(input, input_shape, -127, 127); + SetupTensor(input, input_shape, -20, 20); auto filter_var = scope.get()->Var("filter"); auto filter = filter_var->template GetMutable(); - SetupTensor(filter, filter_shape, -127, 127); + SetupTensor(filter, filter_shape, -20, 20); auto output_var = scope.get()->Var("output"); framework::AttributeMap attrs; @@ -174,28 +174,40 @@ int TestConvOp() { auto *op = new operators::ConvOp("conv2d", inputs, outputs, attrs, scope); + struct timespec ts_begin, ts_end; op->InferShape(); + // warmup op->Run(); - - int kernel_extent_h = dilation_h * (kernel_h - 1) + 1; - int kernel_extent_w = dilation_w * (kernel_w - 1) + 1; - int output_h = (input_h + 2 * pad_h - kernel_extent_h) / stride_h + 1; - int output_w = (input_w + 2 * pad_w - kernel_extent_w) / stride_w + 1; - auto output_shape = framework::make_ddim( - std::vector({batch_size, output_c, output_h, output_w})); - framework::Tensor output_cmp; - output_cmp.mutable_data(output_shape); - conv2d(input, filter, attrs, &output_cmp); - - // compare results - auto output = output_var->template Get(); - const Otype *output_data = output->data(); - Otype *output_cmp_data = output_cmp.data(); - for (int i = 0; i < output->numel(); ++i) { - PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i], - "output[%d] = %d, output_cmp[%d] = %d", i, - output_data[i], i, output_cmp_data[i]); + clock_gettime(CLOCK_MONOTONIC, &ts_begin); + for (int i = 0; i < 10; ++i) { + op->Run(); } + clock_gettime(CLOCK_MONOTONIC, &ts_end); + uint64_t elapsed = (ts_end.tv_sec - ts_begin.tv_sec) * 1e3 + + (ts_end.tv_nsec - ts_begin.tv_nsec) / 1e6; + LOG(kLOG_INFO) << "elapsed: " << elapsed / 10.0 << " ms"; + + /* + int kernel_extent_h = dilation_h * (kernel_h - 1) + 1; + int kernel_extent_w = dilation_w * (kernel_w - 1) + 1; + int output_h = (input_h + 2 * pad_h - kernel_extent_h) / stride_h + 1; + int output_w = (input_w + 2 * pad_w - kernel_extent_w) / stride_w + 1; + auto output_shape = framework::make_ddim( + std::vector({batch_size, output_c, output_h, output_w})); + framework::Tensor output_cmp; + output_cmp.mutable_data(output_shape); + conv2d(input, filter, attrs, &output_cmp); + + // compare results + auto output = output_var->template Get(); + const Otype *output_data = output->data(); + Otype *output_cmp_data = output_cmp.data(); + for (int i = 0; i < 
output->numel(); ++i) { + PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i], + "output[%d] = %d, output_cmp[%d] = %d", i, + output_data[i], i, output_cmp_data[i]); + } + */ delete op; return 0; } @@ -203,12 +215,42 @@ int TestConvOp() { } // namespace paddle_mobile int main() { + // kernel = 7, pad = 0, stride = 2 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=0, stride=2"; + paddle_mobile::TestConvOp(); + + // kernel = 7, pad = 3, stride = 2 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=2"; + paddle_mobile::TestConvOp(); + // kernel = 3, pad = 0, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=3, pad=0, stride=1"; paddle_mobile::TestConvOp(); + // kernel = 3, pad = 0, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=0, stride=1"; + paddle_mobile::TestConvOp(); + LOG(paddle_mobile::kLOG_INFO) << "\n"; + // kernel = 3, pad = 1, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=3, pad=1, stride=1"; paddle_mobile::TestConvOp(); + // kernel = 3, pad = 1, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=1, stride=1"; + paddle_mobile::TestConvOp(); + LOG(paddle_mobile::kLOG_INFO) << "\n"; + // kernel = 5, pad = 0, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=0, stride=1"; paddle_mobile::TestConvOp(); + // kernel = 5, pad = 0, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=0, stride=1"; + paddle_mobile::TestConvOp(); + LOG(paddle_mobile::kLOG_INFO) << "\n"; + // kernel = 5, pad = 2, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=2, stride=1"; paddle_mobile::TestConvOp(); + // kernel = 5, pad = 2, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=2, stride=1"; + paddle_mobile::TestConvOp(); } diff --git a/test/operators/test_mul_op.cpp b/test/operators/test_mul_op.cpp index 8ebf092689..3080100e70 100644 --- a/test/operators/test_mul_op.cpp +++ b/test/operators/test_mul_op.cpp @@ -12,80 +12,80 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "../test_helper.h" #include "../test_include.h" #include "operators/mul_op.h" -int main() { - paddle_mobile::Loader loader; - auto program = loader.Load(g_resnet); - PADDLE_MOBILE_ENFORCE(program.originProgram != nullptr, - "program file read fail"); - - Executor4Test> - executor(program, "mul"); - - // 1. input_tensors; - vector input_tensors; - - Tensor input1; - auto input1_data = CreateInput(&input1, {3, 2, 1, 1}, 0, 1); - input_tensors.push_back(input1); - Tensor input2; - auto input2_data = CreateInput(&input2, {2, 3}, 0, 1); - input_tensors.push_back(input2); - - // 2. input_names - vector input_names({ - "pool2d_0.tmp_0", - "fc_0.w_0", - }); - - // 3. output_names - vector output_names({"fc_0.tmp_0"}); - - // 4. 
out_dims; - vector out_ddims; - auto out_ddim = paddle_mobile::framework::make_ddim({3, 3}); - out_ddims.push_back(out_ddim); - - auto output = executor.Predict(input_tensors, input_names, - output_names, out_ddims); - - auto output0_data = output[0]->data(); - - auto dim_1 = input1.numel() / input1.dims()[0]; - DLOG << " input1 : "; - for (int i = 0; i < input1.dims()[0]; ++i) { - for (int j = 0; j < dim_1; ++j) { - DLOGF("%f ", input1_data[i * dim_1 + j]); +#define a(i, j) a[(i)*lda + (j)] +#define b(i, j) b[(i)*ldb + (j)] +#define c(i, j) c[(i)*ldc + (j)] + +namespace paddle_mobile { +using framework::AttributeMap; +using framework::DDim; +using framework::Scope; +using framework::make_ddim; +template +int TestMulOP() { + int32_t m = 1024; + int32_t n = 1024; + int32_t k = 1024; + int32_t lda = k; + int32_t ldb = n; + int32_t ldc = n; + DDim inputA_shape = make_ddim({m, k}); + DDim inputB_shape = make_ddim({k, n}); + VariableNameMap inputs; + VariableNameMap outputs; + auto scope = std::make_shared(); + inputs["X"] = std::vector({"inputA"}); + inputs["Y"] = std::vector({"inputB"}); + outputs["Out"] = std::vector({"output"}); + + auto inputA_var = scope.get()->Var("inputA"); + auto inputA = inputA_var->template GetMutable(); + SetupTensor(inputA, inputA_shape, -127, 127); + auto inputB_var = scope.get()->Var("inputB"); + auto inputB = inputB_var->template GetMutable(); + SetupTensor(inputB, inputB_shape, -127, 127); + + auto output_var = scope.get()->Var("output"); + AttributeMap attrs; + attrs["x_num_col_dims"].Set(1); + attrs["y_num_col_dims"].Set(1); + auto *op = + new operators::MulOp("mul", inputs, outputs, attrs, scope); + op->InferShape(); + op->Run(); + auto output = output_var->template Get(); + const O *output_data = output->data(); + // compare + O *c = static_cast(memory::Alloc(sizeof(O) * m * n)); + I *a = inputA->data(); + I *b = inputB->data(); + for (int32_t i = 0; i < m; ++i) { + for (int32_t j = 0; j < n; ++j) { + O r = 0; + for (int32_t p = 0; p < k; p++) { + r += static_cast(a(i, p)) * static_cast(b(p, j)); + } + c(i, j) = r; } - DLOGF("\n"); } - auto dim_2 = input2.numel() / input2.dims()[0]; - DLOG << " input2 : "; - for (int i = 0; i < input2.dims()[0]; ++i) { - for (int j = 0; j < dim_2; ++j) { - DLOGF("%f ", input2_data[i * dim_2 + j]); - } - DLOGF("\n"); - } - - auto dim_output0 = output[0]->numel() / output[0]->dims()[0]; - DLOG << " output : "; - for (int i = 0; i < output[0]->dims()[0]; ++i) { - for (int j = 0; j < dim_output0; ++j) { - DLOGF("%f ", output0_data[i * dim_2 + j]); - } - DLOGF("\n"); + for (int32_t i = 0; i < m * n; ++i) { + PADDLE_MOBILE_ENFORCE( + output_data[i] == c[i], "output[%d] = %d, output_cmp[%d] = %d", i, + static_cast(output_data[i]), i, static_cast(c[i])); } + DLOG << "Run MulOp successfully!"; + delete op; + return 0; +} +} // namespace paddle_mobile - /// output (3,3) - DLOG << "output memory size : " << output[0]->memory_size(); - DLOG << "output numel : " << output[0]->numel(); - - DLOG << input1_data[0] << " x " << input2_data[0] << " + " << input1_data[1] - << " x " << input2_data[0 + 3] << " = " << output0_data[0]; +int main() { + paddle_mobile::TestMulOP(); + paddle_mobile::TestMulOP(); return 0; } diff --git a/test/operators/test_quantize_op.cpp b/test/operators/test_quantize_op.cpp index c988862f6d..5b1f276beb 100644 --- a/test/operators/test_quantize_op.cpp +++ b/test/operators/test_quantize_op.cpp @@ -18,14 +18,6 @@ limitations under the License. 
*/ namespace paddle_mobile { -// static float g_test_data[50] = { -// -5.55, -5.5, -5.45, -5.0, -4.55, -4.5, -4.45, -4.0, -3.55, -3.5, -// -3.45, -3.01, -2.75, -2.5, -2.501, -2.49, -2.01, -1.75, -1.5, -1.25, -// -1.0, -0.75, -0.5, -0.25, 0.0, 0.25, 0.5, 0.75, 1.0, 1.25, -// 1.5, 1.75, 2.01, 2.49, 2.501, 2.5, 2.75, 3.01, 3.45, 3.5, -// 3.55, 4.0, 4.45, 4.5, 4.55, 5.0, 5.45, 5.5, 5.55, 6.0, -// }; - static float find_abs_max(const Tensor *input) { float max_abs = 0.f; const float *x = input->data(); @@ -60,6 +52,16 @@ static void quantize_round_to_even(const Tensor *input, const float scale, } } +static void quantize_round_to_nearest(const Tensor *input, const float scale, + Tensor *output) { + const float *x = input->data(); + int8_t *y = output->mutable_data(); + size_t size = input->numel(); + for (size_t i = 0; i < size; ++i) { + y[i] = round(x[i] * scale); + } +} + int TestQuqntizeOp() { framework::DDim dim = framework::make_ddim({1, 3, 224, 224}); @@ -88,15 +90,16 @@ int TestQuqntizeOp() { auto output_scale = output_scale_var->template Get(); const float *output_scale_data = output_scale->data(); - float max_abs = find_abs_max(input); - float output_scale_cmp = 127 / max_abs; + float output_scale_cmp = find_abs_max(input); PADDLE_MOBILE_ENFORCE(output_scale_cmp == output_scale_data[0], "output_scale = %.6f, output_scale_cmp = %.6f", output_scale_cmp, output_scale_data[0]); framework::Tensor output_cmp; output_cmp.Resize(dim); - quantize_round_to_even(input, output_scale_cmp, &output_cmp); + float scale = 127 / output_scale_cmp; + // quantize_round_to_even(input, scale, &output_cmp); + quantize_round_to_nearest(input, scale, &output_cmp); int8_t *output_cmp_data = output_cmp.data(); for (int i = 0; i < output->numel(); ++i) { PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i], -- GitLab
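
A note on the scale convention the updated quantize/dequantize tests encode: test_quantize_op.cpp now expects OutScale to hold max(|x|) itself and compares against round-to-nearest quantization with scale 127 / max(|x|), while test_dequantize_op.cpp builds its reference with the ratio 1.27 / 1.74 instead of 1 / (1.27 * 1.74). The sketch below restates that convention in standalone form; the dequantize signature and its scale_a/scale_b parameter names are illustrative, not part of the patch.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Quantization as the updated test_quantize_op.cpp models it: the recorded
// scale is max(|x|) itself, and values are quantized with
// y = round(x * 127 / max_abs).  No clamp is needed because |x| <= max_abs
// implies |x * 127 / max_abs| <= 127.
static float quantize_round_to_nearest(const float *x, std::size_t n,
                                       int8_t *y) {
  float max_abs = 1e-6f;
  for (std::size_t i = 0; i < n; ++i) {
    max_abs = std::max(max_abs, std::fabs(x[i]));
  }
  const float scale = 127.f / max_abs;
  for (std::size_t i = 0; i < n; ++i) {
    y[i] = static_cast<int8_t>(std::round(x[i] * scale));
  }
  return max_abs;  // what the quantize op now records in OutScale
}

// Dequantization as the updated test_dequantize_op.cpp models it: the
// reference factor is the ratio of the two recorded scales (1.27 / 1.74 in
// the test) rather than the reciprocal of their product.
static void dequantize(const int32_t *x, std::size_t n, float scale_a,
                       float scale_b, float *y) {
  const float scale = scale_a / scale_b;
  for (std::size_t i = 0; i < n; ++i) {
    y[i] = x[i] * scale;
  }
}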
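
The new test_gemm_int8_accuracy.cpp checks Gemm::Sgemm for int8 operands against a naive reference that widens every product to int32 before accumulating. A compact, self-contained version of that reference (the function name ref_gemm_s8s8s32 is illustrative):

#include <cstdint>

// Reference GEMM used to check the int8 kernel: C[m x n] = A[m x k] * B[k x n].
// Operands are widened to int32 before accumulation; for the test sizes
// (k up to 1024, |a*b| <= 127*127) the sums stay well within int32 range.
static void ref_gemm_s8s8s32(int m, int n, int k, const int8_t *a, int lda,
                             const int8_t *b, int ldb, int32_t *c, int ldc) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      int32_t acc = 0;
      for (int p = 0; p < k; ++p) {
        acc += static_cast<int32_t>(a[i * lda + p]) *
               static_cast<int32_t>(b[p * ldb + j]);
      }
      c[i * ldc + j] = acc;
    }
  }
}

The test then runs the optimized kernel with alpha = 1 and beta = 0 on random values in [-127, 127] and counts matching and mismatching elements.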
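
matmul<int8_t> in math_function_int8.cpp handles trans_a by copying A into a transposed temporary before calling Sgemm, so the kernel only ever sees a row-major, non-transposed first operand. The copy is a plain column-by-column walk; a sketch using std::vector for the scratch buffer (the real code allocates through paddle_mobile::memory::Alloc instead):

#include <cstddef>
#include <cstdint>
#include <vector>

// Copy an m x n row-major int8 matrix into its n x m transpose so the GEMM
// can always treat its first operand as non-transposed.
static std::vector<int8_t> transpose_rowmajor(const int8_t *src, int m, int n) {
  std::vector<int8_t> dst(static_cast<std::size_t>(m) * n);
  std::size_t index = 0;
  for (int j = 0; j < n; ++j) {    // walk columns of src ...
    for (int i = 0; i < m; ++i) {  // ... emitting them as rows of dst
      dst[index++] = src[i * n + j];
    }
  }
  return dst;
}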
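
Both the scalar fallback in the float Im2ColFunctor and the new int8_t specialization use the same index mapping from the [channels * kh * kw, out_h, out_w] column buffer back into the CHW image, writing zeros for taps that land in the padding. A standalone version of that mapping:

// col is laid out as [channels * kh * kw, out_h, out_w]; each column row c
// maps to a (channel, kernel-row, kernel-col) triple, and each output
// position (h, w) reads one input pixel or a zero when the tap falls into
// the padding region.
template <typename T>
void im2col_ref(const T *im, int channels, int im_h, int im_w, int kh, int kw,
                int stride_h, int stride_w, int pad_h, int pad_w,
                int dilation_h, int dilation_w, int out_h, int out_w, T *col) {
  const int channels_col = channels * kh * kw;
  for (int c = 0; c < channels_col; ++c) {
    const int w_offset = c % kw;
    const int h_offset = (c / kw) % kh;
    const int c_im = c / (kw * kh);
    for (int h = 0; h < out_h; ++h) {
      const int im_row = h * stride_h - pad_h + h_offset * dilation_h;
      for (int w = 0; w < out_w; ++w) {
        const int im_col = w * stride_w - pad_w + w_offset * dilation_w;
        const bool in_bounds =
            im_row >= 0 && im_row < im_h && im_col >= 0 && im_col < im_w;
        col[(c * out_h + h) * out_w + w] =
            in_bounds ? im[(c_im * im_h + im_row) * im_w + im_col]
                      : static_cast<T>(0);
      }
    }
  }
}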
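
test_int8_conv_op.cpp now times ten Run() calls with clock_gettime(CLOCK_MONOTONIC, ...) after a single warm-up run. A minimal sketch of that timing harness, assuming a POSIX target; the averaging over ten iterations mirrors the test:

#include <cstdio>
#include <time.h>

// Milliseconds between two CLOCK_MONOTONIC samples, combining the seconds
// and nanoseconds fields in double precision.
static double elapsed_ms(const timespec &begin, const timespec &end) {
  return (end.tv_sec - begin.tv_sec) * 1e3 +
         (end.tv_nsec - begin.tv_nsec) / 1e6;
}

int main() {
  timespec ts_begin, ts_end;
  clock_gettime(CLOCK_MONOTONIC, &ts_begin);
  // ... run the operator a fixed number of times here, e.g. 10 x op->Run() ...
  clock_gettime(CLOCK_MONOTONIC, &ts_end);
  std::printf("average elapsed: %.3f ms\n",
              elapsed_ms(ts_begin, ts_end) / 10.0);
  return 0;
}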