diff --git a/src/common/types.cpp b/src/common/types.cpp index 312e491a35681e2fc75584106160a4c79e22e372..cf2c4dc87613b4641d7c1126e22d2e4a45ff9594 100644 --- a/src/common/types.cpp +++ b/src/common/types.cpp @@ -24,6 +24,7 @@ const char *G_OP_TYPE_CONCAT = "concat"; const char *G_OP_TYPE_ELEMENTWISE_ADD = "elementwise_add"; const char *G_OP_TYPE_FILL_CONSTANT = "fill_constant"; const char *G_OP_TYPE_FUSION_CONV_ADD_RELU = "fusion_conv_add_relu"; +const char *G_OP_TYPE_FUSION_CONV_ADD_RELU_INT8 = "fusion_conv_add_relu_int8"; const char *G_OP_TYPE_FUSION_CONV_ADD_PRELU = "fusion_conv_add_prelu"; const char *G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU = "fusion_conv_add_add_prelu"; const char *G_OP_TYPE_FUSION_CONV_ADD_BN_RELU = "fusion_conv_add_bn_relu"; @@ -115,6 +116,7 @@ std::unordered_map< {G_OP_TYPE_DEPTHWISE_CONV, {{"Input"}, {"Output"}}}, {G_OP_TYPE_FILL_CONSTANT, {{}, {"Out"}}}, {G_OP_TYPE_FUSION_CONV_ADD_RELU, {{"Input"}, {"Out"}}}, + {G_OP_TYPE_FUSION_CONV_ADD_RELU_INT8, {{"Input"}, {"Out"}}}, {G_OP_TYPE_FUSION_CONV_ADD_PRELU, {{"Input"}, {"Out"}}}, {G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU, {{"Input"}, {"Out"}}}, {G_OP_TYPE_IM2SEQUENCE, {{"X"}, {"Out"}}}, diff --git a/src/common/types.h b/src/common/types.h index 16ed1aef57432249b14c415b3a23042ca295b600..a63d2efd23ebdef1ebb0b6d40d356c33574b3818 100644 --- a/src/common/types.h +++ b/src/common/types.h @@ -99,6 +99,7 @@ extern const char *G_OP_TYPE_BOX_CODER; extern const char *G_OP_TYPE_CONCAT; extern const char *G_OP_TYPE_ELEMENTWISE_ADD; extern const char *G_OP_TYPE_FUSION_CONV_ADD_RELU; +extern const char *G_OP_TYPE_FUSION_CONV_ADD_RELU_INT8; extern const char *G_OP_TYPE_FUSION_CONV_ADD_PRELU; extern const char *G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU; extern const char *G_OP_TYPE_FC; diff --git a/src/framework/op_registry.h b/src/framework/op_registry.h index 219385ab1429fefddc9d380799259f7562e0030f..52cae493ea0d62bc06df70933774f8c808b7c93d 100644 --- a/src/framework/op_registry.h +++ b/src/framework/op_registry.h @@ -98,6 +98,24 @@ class OpRegistry { } }; +#define REGISTER_OPERATOR_INT8(op_type, op_class, device_name, device_type) \ + template class op_class; \ + template \ + class _OpClass_##op_type##_##device_name : public op_class { \ + public: \ + DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_##device_name, op_class); \ + }; \ + static paddle_mobile::framework::OperatorRegistrar< \ + device_type, _OpClass_##op_type##_##device_name> \ + __op_registrar_##op_type##_##device_name(#op_type); \ + int TouchOpRegistrar_##op_type##_##device_name() { \ + __op_registrar_##op_type##_##device_name.Touch(); \ + return 0; \ + } + +#define REGISTER_OPERATOR_CPU_INT8(op_type, op_class) \ + REGISTER_OPERATOR_INT8(op_type, op_class, cpu, paddle_mobile::CPU); + #define REGISTER_OPERATOR(op_type, op_class, device_name, device_type) \ template class op_class; \ template \ diff --git a/src/io/paddle_mobile.cpp b/src/io/paddle_mobile.cpp index 4b50f15a868e3bdbb8434af0cc0d49a6cb54c6a5..cb7051468715179e1d9a5ead407941a20d9cb87a 100644 --- a/src/io/paddle_mobile.cpp +++ b/src/io/paddle_mobile.cpp @@ -153,7 +153,8 @@ double PaddleMobile::GetPredictTime() { paddle_mobile::operators::math::Gemm gemm; auto time1 = paddle_mobile::time(); gemm.Sgemm(m, n, k, static_cast(1), a, lda, b, ldb, - static_cast(0), c, ldc, false, nullptr); + static_cast(0), c, ldc, false, + static_cast(nullptr)); auto time2 = paddle_mobile::time(); double cost = paddle_mobile::time_diff(time1, time2); paddle_mobile::memory::Free(a); diff --git a/src/operators/fusion_conv_add_relu_int8_op.cpp 
b/src/operators/fusion_conv_add_relu_int8_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ac0226ec7ad7b9744cda10a879b8d29a21a8e152 --- /dev/null +++ b/src/operators/fusion_conv_add_relu_int8_op.cpp @@ -0,0 +1,56 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef FUSION_CONVADDRELU_INT8_OP + +#include "operators/fusion_conv_add_relu_int8_op.h" +#include +#include "operators/math/conv_func.h" + +namespace paddle_mobile { +namespace operators { + +template +void FusionConvAddReluInt8Op::InferShape() const { + auto in_dims = this->param_.Input()->dims(); + auto filter_dims = this->param_.Filter()->dims(); + const std::vector &strides = this->param_.Strides(); + std::vector paddings = this->param_.Paddings(); + int groups = this->param_.Groups(); + std::vector dilations = this->param_.Dilations(); + + PADDLE_MOBILE_ENFORCE((in_dims.size() == filter_dims.size() && + dilations.size() == paddings.size() && + paddings.size() == strides.size()), + "ConvParam is not suitable"); + + std::vector output_shape({in_dims[0], filter_dims[0]}); + for (size_t i = 0; i < strides.size(); ++i) { + output_shape.push_back( + math::ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i], + paddings[i], strides[i])); + } + framework::DDim ddim = framework::make_ddim(output_shape); + this->param_.Output()->Resize(ddim); +} + +} // namespace operators +} // namespace paddle_mobile + +namespace ops = paddle_mobile::operators; +#ifdef PADDLE_MOBILE_CPU +REGISTER_OPERATOR_CPU_INT8(fusion_conv_add_relu_int8, + ops::FusionConvAddReluInt8Op); +#endif +#endif // FUSION_CONVADDRELU_INT8_OP diff --git a/src/operators/fusion_conv_add_relu_int8_op.h b/src/operators/fusion_conv_add_relu_int8_op.h new file mode 100644 index 0000000000000000000000000000000000000000..5e4b4c08065de8111ae5511b5e9448bacda74c8b --- /dev/null +++ b/src/operators/fusion_conv_add_relu_int8_op.h @@ -0,0 +1,42 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifdef FUSION_CONVADDRELU_INT8_OP +#pragma once +#include +#include "framework/operator.h" +#include "operators/kernel/conv_add_relu_kernel.h" +#include "operators/op_param.h" +namespace paddle_mobile { +namespace operators { +template +class FusionConvAddReluInt8Op + : public framework::OperatorWithKernel< + DeviceType, FusionConvAddReluParam, + operators::ConvAddReluKernel> { + public: + FusionConvAddReluInt8Op(const std::string &type, + const VariableNameMap &inputs, + const VariableNameMap &outputs, + const framework::AttributeMap &attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel< + DeviceType, FusionConvAddReluParam, + operators::ConvAddReluKernel>(type, inputs, outputs, + attrs, scope) {} + void InferShape() const override; +}; +} // namespace operators +} // namespace paddle_mobile +#endif // FUSION_CONVADDRELU_INT8_OP diff --git a/src/operators/kernel/arm/conv_add_relu_kernel.cpp b/src/operators/kernel/arm/conv_add_relu_kernel.cpp index 211d6d8487bfd4afc71d74e5ecbff149ad34e466..150bf1d77e33b99cbd7786f3885f2012270c0c78 100644 --- a/src/operators/kernel/arm/conv_add_relu_kernel.cpp +++ b/src/operators/kernel/arm/conv_add_relu_kernel.cpp @@ -28,10 +28,24 @@ bool ConvAddReluKernel::Init(FusionConvAddReluParam *param) { template <> void ConvAddReluKernel::Compute( const FusionConvAddReluParam ¶m) { - ConvAddReluCompute(param); + ConvAddReluCompute(param); } template class ConvAddReluKernel; +#ifdef FUSION_CONVADDRELU_INT8_OP +template <> +bool ConvAddReluKernel::Init(FusionConvAddReluParam *param) { + return true; +} + +template <> +void ConvAddReluKernel::Compute( + const FusionConvAddReluParam ¶m) { + ConvAddReluCompute(param); +} +template class ConvAddReluKernel; +#endif + } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h index 36886b9e2ccfaaa3f557eb7941e294a42b5edb94..9e46790cfe6f8d21f6c466c64853b5efc7db927c 100644 --- a/src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h @@ -25,21 +25,31 @@ limitations under the License. */ namespace paddle_mobile { namespace operators { -template +template void ConvAddReluCompute(const FusionConvAddReluParam ¶m) { const Tensor *input = param.Input(); Tensor filter = *param.Filter(); Tensor bias = *param.Bias(); - int axis = param.Axis(); + int32_t axis = param.Axis(); + S *bias_data = bias.data(); Tensor *output = param.Output(); - float *biase_data = bias.data(); + output->mutable_data
(); - int groups = param.Groups(); - std::vector strides = param.Strides(); - std::vector paddings = param.Paddings(); - std::vector dilations = param.Dilations(); + float alpha = 1.0f; + float beta = 1.0f; - const int batch_size = static_cast(input->dims()[0]); +#ifdef FUSION_CONVADDRELU_INT8_OP + Tensor scale = *param.InputScale(); + alpha = scale.data()[0]; + beta = 0.0f; +#endif + + int32_t groups = param.Groups(); + std::vector strides = param.Strides(); + std::vector paddings = param.Paddings(); + std::vector dilations = param.Dilations(); + + const int32_t batch_size = static_cast(input->dims()[0]); std::vector filter_shape_vec(framework::vectorize(filter.dims())); @@ -61,13 +71,13 @@ void ConvAddReluCompute(const FusionConvAddReluParam ¶m) { Tensor col; Tensor col_matrix; if (is_expand) { - col.mutable_data(col_shape); + col.mutable_data
(col_shape); col_matrix.ShareDataWith(col); col_matrix.Resize(col_matrix_shape); } framework::DDim input_shape = framework::slice_ddim( - input->dims(), 1, static_cast(input->dims().size())); + input->dims(), 1, static_cast(input->dims().size())); framework::DDim filter_matrix_shape = {filter.dims()[0], filter.numel() / filter.dims()[0]}; @@ -77,17 +87,17 @@ void ConvAddReluCompute(const FusionConvAddReluParam ¶m) { output->numel() / (output->dims()[0] * output->dims()[1])}; // convolution operator: im2col(or vol2col) + gemm - int in_step = static_cast(input->dims()[1]) / groups; - int out_step = static_cast(output->dims()[1]) / groups; + int32_t in_step = static_cast(input->dims()[1]) / groups; + int32_t out_step = static_cast(output->dims()[1]) / groups; - math::Vol2ColFunctor vol2col; - math::Im2ColFunctor im2col; + math::Vol2ColFunctor vol2col; + math::Im2ColFunctor im2col; - for (int i = 0; i < batch_size; i++) { + for (int32_t i = 0; i < batch_size; i++) { Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); - for (int g = 0; g < groups; g++) { + for (int32_t g = 0; g < groups; g++) { Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); if (!is_expand) { @@ -97,8 +107,8 @@ void ConvAddReluCompute(const FusionConvAddReluParam ¶m) { } else if (data_dim == 2U) { // im2col im2col(in_slice, dilations, strides, - std::vector{paddings[0], paddings[1], paddings[0], - paddings[1]}, + std::vector{paddings[0], paddings[1], paddings[0], + paddings[1]}, &col); } else if (data_dim == 3U) { // vol2col @@ -108,9 +118,9 @@ void ConvAddReluCompute(const FusionConvAddReluParam ¶m) { // gemm Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(filter_slice, false, col_matrix, false, - static_cast(1), &out_slice, - static_cast(1), true, biase_data); + + math::matmul(filter_slice, false, col_matrix, false, alpha, &out_slice, + beta, true, bias_data); } } } diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.h b/src/operators/kernel/central-arm-func/conv_arm_func.h index 7f40157c30ad19472045eb53bd7a99e577429db5..11667dfcc9cf2e25712a8f5c57d665cd41e9a9c6 100644 --- a/src/operators/kernel/central-arm-func/conv_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_arm_func.h @@ -106,10 +106,9 @@ inline void GemmConv(const ConvParam ¶m) { // gemm Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - - math::matmul(filter_slice, false, col_matrix, false, - static_cast(1), &out_slice, - static_cast(0)); + math::matmul(filter_slice, false, col_matrix, false, + static_cast(1), &out_slice, static_cast(0), + false, static_cast(nullptr)); } } } diff --git a/src/operators/kernel/central-arm-func/mul_arm_func.h b/src/operators/kernel/central-arm-func/mul_arm_func.h index 07e634e3be9648520357871d91d6677aec6b5c0e..8b9dad90a0b02ebf761bcd44fabc18905b056e6e 100644 --- a/src/operators/kernel/central-arm-func/mul_arm_func.h +++ b/src/operators/kernel/central-arm-func/mul_arm_func.h @@ -73,8 +73,9 @@ void MulCompute(const MulParam ¶m) { } if (param.InputX()->type() == typeid(int8_t)) { out->mutable_data(); - math::matmul(x_matrix, false, y_matrix, false, - static_cast(1), out, static_cast(0)); + math::matmul(x_matrix, false, y_matrix, false, + static_cast(1), out, + static_cast(0)); } else { out->mutable_data(); diff --git 
a/src/operators/kernel/central-arm-func/pool_arm_func.h b/src/operators/kernel/central-arm-func/pool_arm_func.h index 941c237865707bce854aedba56029a4f5de9b2bf..15a2ebc65ee8dc90e148dc13e44596d63abdf35c 100644 --- a/src/operators/kernel/central-arm-func/pool_arm_func.h +++ b/src/operators/kernel/central-arm-func/pool_arm_func.h @@ -23,20 +23,22 @@ namespace paddle_mobile { namespace operators { using framework::Tensor; -inline void PoolBasic(std::string pooling_type, std::vector ksize, - std::vector strides, std::vector paddings, - const Tensor *in_x, Tensor *out) { +template +void PoolBasic(std::string pooling_type, std::vector ksize, + std::vector strides, std::vector paddings, + const Tensor *in_x, Tensor *out) { if (pooling_type == "max") { - math::PoolFunctor, float> pool2d_forward; - math::MaxPool pool_process; + math::PoolFunctor, T> pool2d_forward; + math::MaxPool pool_process; pool2d_forward(*in_x, ksize, strides, paddings, pool_process, out); } else if (pooling_type == "avg") { - math::PoolFunctor, float> pool2d_forward; - math::AvgPool pool_process; + math::PoolFunctor, T> pool2d_forward; + math::AvgPool pool_process; pool2d_forward(*in_x, ksize, strides, paddings, pool_process, out); } } + template void PoolCompute(const PoolParam ¶m) { const Tensor *in_x = param.Input(); @@ -52,50 +54,67 @@ void PoolCompute(const PoolParam ¶m) { LOG(paddle_mobile::LogLevel::kLOG_ERROR) << "Pool op only supports 2D and 3D input."; } - if (param.isGlobalPooling()) { for (size_t i = 0; i < ksize.size(); ++i) { paddings[i] = 0; ksize[i] = static_cast(in_x->dims()[i + 2]); } } - if (ksize[0] == 3 && ksize[0] == ksize[1]) { - if (pooling_type == "max") { - if (strides[0] == strides[1] && strides[0] == 1 && - paddings[0] == paddings[1] && paddings[1] == 1) { - math::Pool3x3Maxs1p1(in_x, out); + if (in_x->type() == typeid(int8_t)) { + if (pooling_type == "max" && ksize[0] == 3 && ksize[0] == ksize[1]) { + if (strides[0] == strides[1] && strides[0] == 1) { + math::Pool3x3Maxs1_int8(in_x, out, paddings[0], paddings[1]); + } else if (strides[0] == strides[1] && strides[0] == 2) { + math::Pool3x3Maxs2_int8(in_x, out, paddings[0], paddings[1]); } else { - math::Pool3x3Max(strides, paddings, in_x, out); - } - } else if (pooling_type == "avg") { - if (strides[0] == strides[1] && strides[0] == 1 && - paddings[0] == paddings[1] && paddings[1] == 1) { - math::Pool3x3Avgs1p1(in_x, out); - } else { - math::Pool3x3Avg(strides, paddings, in_x, out); + math::Pool3x3Max_int8(strides, paddings, in_x, out); } + } else { + PoolBasic(pooling_type, ksize, strides, paddings, in_x, + out); } + } else { + if (ksize[0] == 3 && ksize[0] == ksize[1]) { + if (pooling_type == "max") { + if (strides[0] == strides[1] && strides[0] == 1 && + paddings[0] == paddings[1] && paddings[1] == 1) { + math::Pool3x3Maxs1p1(in_x, out); + } else { + math::Pool3x3Max(strides, paddings, in_x, out); + } + } else if (pooling_type == "avg") { + if (strides[0] == strides[1] && strides[0] == 1 && + paddings[0] == paddings[1] && paddings[1] == 1) { + math::Pool3x3Avgs1p1(in_x, out); + } else { + math::Pool3x3Avg(strides, paddings, in_x, out); + } + } - } else if (ksize[0] == 2 && ksize[0] == ksize[1] && strides[0] == 2 && - strides[0] == strides[1] && paddings[0] == paddings[1] && - paddings[1] == 0) { + } else if (ksize[0] == 2 && ksize[0] == ksize[1] && strides[0] == 2 && + strides[0] == strides[1] && paddings[0] == paddings[1] && + paddings[1] == 0) { #if __ARM_NEON #if __aarch64__ - PoolBasic(pooling_type, ksize, strides, paddings, in_x, out); + 
PoolBasic(pooling_type, ksize, strides, paddings, in_x, + out); #else - /// todo: fix bug in Pool2x2 - if (pooling_type == "max") { - math::Pool2x2Maxs2p0(strides, paddings, in_x, out); - } else if (pooling_type == "avg") { - math::Pool2x2Avgs2p0(strides, paddings, in_x, out); - } + /// todo: fix bug in Pool2x2 + if (pooling_type == "max") { + math::Pool2x2Maxs2p0(strides, paddings, in_x, out); + } else if (pooling_type == "avg") { + math::Pool2x2Avgs2p0(strides, paddings, in_x, out); + } #endif #else - PoolBasic(pooling_type, ksize, strides, paddings, in_x, out); + PoolBasic(pooling_type, ksize, strides, paddings, in_x, + out); #endif // __ARM_NEON - } else { - PoolBasic(pooling_type, ksize, strides, paddings, in_x, out); + } else { + PoolBasic(pooling_type, ksize, strides, paddings, in_x, + out); + } } } diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp index c17b2a5e4df0f0ca88da79a9ce55c2ecae0316b5..ae324dbfd383aa2aa93b848710ff5d67c7b4893c 100644 --- a/src/operators/math/gemm.cpp +++ b/src/operators/math/gemm.cpp @@ -2924,6 +2924,7 @@ void Gemm::WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc, #endif // __ARM_NEON // 32位 float 矩阵乘法 +template <> void Gemm::Sgemm(int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc, bool relu, float *bias) { @@ -3146,6 +3147,7 @@ void Gemm::SgemmWithPRelu(int m, int n, int k, const float *A, int lda, } // 32位 float 矩阵乘法 +template <> void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc, bool relu, float *bias) { diff --git a/src/operators/math/gemm.h b/src/operators/math/gemm.h index 8498992fcecbcb2c9a773fba874e108c013a04fc..61e957100b35ee2bd16f03ffeec24a8b85339237 100644 --- a/src/operators/math/gemm.h +++ b/src/operators/math/gemm.h @@ -15,6 +15,10 @@ limitations under the License. */ #pragma once #include #include "common/log.h" +#include "memory/t_malloc.h" +#ifdef _OPENMP +#include +#endif // 矩阵取值运算宏,假设矩阵按行存储 #define A(i, j) A[(i)*lda + (j)] @@ -23,10 +27,12 @@ limitations under the License. 
*/ #if __aarch64__ #define MR_INT8 4 +#define NR_INT8 2 #define MR 6 #define NR 16 #else #define MR_INT8 4 +#define NR_INT8 2 #define MR 6 #define NR 8 #endif @@ -161,11 +167,6 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, float *new_bias); */ - // 32位 float 矩阵乘法 - void Sgemm(int m, int n, int k, float alpha, const float *A, int lda, - const float *B, int ldb, float beta, float *C, int ldc, bool relu, - float *bias); - // 32位 float 矩阵乘法, 并对结果进行 batchnrom void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc, @@ -174,11 +175,6 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, const float *B, int ldb, float *C, int ldc, float *p, std::string mode, float *bias, float *bias1); - // 32位 float 矩阵乘法(openmp 多线程版本) - void Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, - const float *B, int ldb, float beta, float *C, int ldc, - bool relu, float *bias); - // 32位 float 矩阵乘法, 并对结果进行 batchnrom(openmp 多线程版本) void SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, @@ -193,52 +189,67 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, // 8 bits int small block inner product void AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, int32_t ldc); + void AddDot4x2(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, + int32_t ldc); void AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, int32_t ldc); // 8 bits int inner product - void InnerKernelWithBias(int32_t mc, int32_t nc, int8_t alpha, - const int8_t *a, const int8_t *b, int8_t beta, - int32_t *c, int32_t *C, int32_t ldc, bool relu, - int8_t *bias); + template + void InnerKernel(int32_t mc, int32_t nc, float alpha, const int8_t *a, + const int8_t *b, float beta, int32_t *c, Otype *C, + int32_t ldc, bool relu); + template + void InnerKernelWithBias(int32_t mc, int32_t nc, float alpha, const int8_t *a, + const int8_t *b, float beta, int32_t *c, Otype *C, + int32_t ldc, bool relu, int32_t *bias); // 8 bits int pack function void PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, int32_t lda, int8_t *buffer); + void PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, + int32_t lda, int8_t *buffer); void PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, int32_t lda, int8_t *buffer); + void PackMatrixB_2c_16(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, + int32_t ldb, int8_t *buffer); void PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, int32_t ldb, int8_t *buffer); void PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, int32_t lda, int8_t *buffer); void PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, int32_t ldb, int8_t *buffer); + void PackMatrixA_omp_4r_16(int32_t m, int32_t k, int32_t m_tail, + const int8_t *A, int32_t lda, int8_t *buffer); + void PackMatrixB_omp_2c_16(int32_t k, int32_t n, int32_t n_tail, + const int8_t *B, int32_t ldb, int8_t *buffer); // 8 bits int matrix product - void Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A, - int32_t lda, const int8_t *B, int32_t ldb, int8_t beta, int32_t *C, - int32_t ldc, bool relu, int8_t *bias); - void Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A, - int32_t lda, const int8_t *B, int32_t ldb, int8_t beta, - int32_t *C, int32_t ldc, bool relu, int8_t 
*bias); + template + void Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, const Itype *A, + int32_t lda, const Itype *B, int32_t ldb, float beta, Otype *C, + int32_t ldc, bool relu, Btype *bias); + template + void Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A, + int32_t lda, const int8_t *B, int32_t ldb, float beta, + Otype *C, int32_t ldc, bool relu, int32_t *bias); + template + void Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const Itype *A, + int32_t lda, const Itype *B, int32_t ldb, float beta, Otype *C, + int32_t ldc, bool relu, Btype *bias); + template + void Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A, + int32_t lda, const int8_t *B, int32_t ldb, float beta, Otype *C, + int32_t ldc, bool relu, int32_t *bias); // 8 bits int write back - // C = alpha * A * B + beta * C - void WriteWithAlphaBeta(int32_t mc, int32_t nc, int32_t *c, int32_t *C, - int32_t ldc); // C = A * B void WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, int32_t ldc); - // C = A * B + C - void WriteWithAdd(int32_t mc, int32_t nc, int32_t *c, int32_t *C, - int32_t ldc); - // C = A * B + bias - void WriteWithAddV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C, - int32_t ldc, int8_t *bias); - // C = A * B + C, relu(C) - void WriteWithAddRelu(int32_t mc, int32_t nc, int32_t *c, int32_t *C, - int32_t ldc); - // C = A * B + bias, relu(C) - void WriteWithAddReluV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C, - int32_t ldc, int8_t *bias); + // C = A * B + bias, scale * relu(C) + void WriteWithAddReluScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C, + int32_t ldc, int32_t *bias, float scale); + // C = A * B + bias, scale * C + void WriteWithAddScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C, + int32_t ldc, int32_t *bias, float scale); private: int MC = 0; @@ -254,10 +265,200 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, // 8 bits int int8_t *packedA_int8; int8_t *packedB_int8; - int32_t *packedC_int8; + int32_t *packedC_int32; int8_t *zero_int8; }; +// 8 bits int matrix product (m*k x k*n) +template +void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A, + int32_t lda, const int8_t *B, int32_t ldb, float beta, + Otype *C, int32_t ldc, bool relu, int32_t *bias) { + // L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73) + // L2 cache is 0.5~4 Mib (Contex-A72 cluster) + int32_t L1 = 32 * 1024; + int32_t L2 = 512 * 1024; + + const int32_t k_complete = (k + 15) - ((k + 15) & 15); + KC = k_complete; + MC = L1 / (KC * sizeof(int8_t)); + NC = L2 / (KC * sizeof(int8_t)); + + // make sure MC is multiple of MR_INT8, and NC is multiple of NR_INT8 + if (MC == 0) { + MC = MR_INT8; + } else { + int32_t mblock_num = (m + MC - 1) / MC; + MC = (m + mblock_num - 1) / mblock_num; + MC = (MC + MR_INT8 - 1) / MR_INT8 * MR_INT8; + } + // DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n"; + if (NC == 0) { + NC = NR_INT8; + } else { + int32_t nblock_num = (n + NC - 1) / NC; + NC = (n + nblock_num - 1) / nblock_num; + NC = (NC + NR_INT8 - 1) / NR_INT8 * NR_INT8; + } + // DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n"; + packedA_int8 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC)); + packedB_int8 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC)); + packedC_int32 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC)); + zero_int8 = + static_cast(paddle_mobile::memory::Alloc(sizeof(int8_t) * k)); + + 
memset(static_cast(zero_int8), 0, sizeof(int8_t) * k); + int32_t mc, nc; + for (int32_t j = 0; j < n; j += NC) { + nc = s_min(n - j, NC); + PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, packedB_int8); + for (int32_t i = 0; i < m; i += MC) { + mc = s_min(m - i, MC); + PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, packedA_int8); + if (bias == nullptr) { + InnerKernel(mc, nc, alpha, packedA_int8, packedB_int8, beta, + packedC_int32, &C(i, j), ldc, relu); + } else { + InnerKernelWithBias(mc, nc, alpha, packedA_int8, packedB_int8, beta, + packedC_int32, &C(i, j), ldc, relu, bias + i); + } + } + } + + paddle_mobile::memory::Free(packedA_int8); + paddle_mobile::memory::Free(packedB_int8); + paddle_mobile::memory::Free(packedC_int32); + paddle_mobile::memory::Free(zero_int8); +} + +// 8 bits int matrix product (m*k x k*n), omp version +template +void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, + const int8_t *A, int32_t lda, const int8_t *B, int32_t ldb, + float beta, Otype *C, int32_t ldc, bool relu, + int32_t *bias) { +#ifdef _OPENMP + int32_t max_threads = omp_get_max_threads(); +#else + int32_t max_threads = 1; +#endif + + int32_t L1 = 64 / max_threads * 1024; + const int32_t k_complete = (k + 15) - ((k + 15) & 15); + KC = k_complete; + zero_int8 = + static_cast(paddle_mobile::memory::Alloc(sizeof(int8_t) * k)); + memset(static_cast(zero_int8), 0, sizeof(int8_t) * k); + if (m > n) { + // 对 A 分块 + MC = L1 / (KC * sizeof(int8_t)); + if (MC == 0) { + MC = MR_INT8; + } else { + int32_t mblock_num = (m + MC - 1) / MC; + MC = (m + mblock_num - 1) / mblock_num; + MC = (MC + MR_INT8 - 1) / MR_INT8 * MR_INT8; + } + // 补齐 B + NC = (n + NR_INT8 - 1) / NR_INT8 * NR_INT8; + + packedB_int8 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC)); +#if __aarch64__ + // TODO() +#else + PackMatrixB_omp_2c_16(k, n, n % NR_INT8, B, ldb, packedB_int8); +#endif + packedA_int8 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC * max_threads)); + } else { + // 对 B 分块 + NC = L1 / (KC * sizeof(int8_t)); + if (NC == 0) { + NC = NR_INT8; + } else { + int32_t nblock_num = (n + NC - 1) / NC; + NC = (n + nblock_num - 1) / nblock_num; + NC = (NC + NR_INT8 - 1) / NR_INT8 * NR_INT8; + } + // 补齐 A + MC = (m + MR_INT8 - 1) / MR_INT8 * MR_INT8; + + packedA_int8 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC)); +#if __aarch64__ + // TODO() +#else + PackMatrixA_omp_4r_16(m, k, m % MR_INT8, A, lda, packedA_int8); +#endif + packedB_int8 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC * max_threads)); + } + packedC_int32 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC * max_threads)); + + if (m > n) { +#pragma omp parallel for + for (int32_t i = 0; i < m; i += MC) { +#ifdef _OPENMP + int32_t local_threads = omp_get_thread_num(); +#else + int32_t local_threads = 0; +#endif + + int32_t mc; + mc = s_min(m - i, MC); + int8_t *local_A = packedA_int8 + MC * KC * local_threads; + int32_t *local_C = packedC_int32 + MC * NC * local_threads; +#if __aarch64__ + // TODO() +#else + PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, local_A); +#endif + if (bias == nullptr) { + InnerKernel(mc, n, alpha, local_A, packedB_int8, beta, local_C, + &C(i, 0), ldc, relu); + } else { + InnerKernelWithBias(mc, n, alpha, local_A, packedB_int8, beta, local_C, + &C(i, 0), ldc, relu, bias + i); + } + } + } else { +#pragma omp parallel for + for (int32_t j = 0; j < n; j += NC) { +#ifdef _OPENMP + int32_t local_threads = 
omp_get_thread_num(); +#else + int32_t local_threads = 0; +#endif + int32_t nc; + nc = s_min(n - j, NC); + int8_t *local_B = packedB_int8 + KC * NC * local_threads; + int32_t *local_C = packedC_int32 + MC * NC * local_threads; +#if __aarch64__ + // TODO() +#else + PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, local_B); +#endif + if (bias == nullptr) { + InnerKernel(m, nc, alpha, packedA_int8, local_B, beta, local_C, + &C(0, j), ldc, relu); + } else { + InnerKernelWithBias(m, nc, alpha, packedA_int8, local_B, beta, local_C, + &C(0, j), ldc, relu, bias); + } + } + } + + paddle_mobile::memory::Free(packedA_int8); + paddle_mobile::memory::Free(packedB_int8); + paddle_mobile::memory::Free(packedC_int32); + paddle_mobile::memory::Free(zero_int8); +} + } // namespace math } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/math/gemm_int8.cpp b/src/operators/math/gemm_int8.cpp index b16db7fe6acf0c3c7fb2902c9fb3f6e3dc81a65f..1659045c3f3868412d53a578447215a91c4b2d7f 100644 --- a/src/operators/math/gemm_int8.cpp +++ b/src/operators/math/gemm_int8.cpp @@ -14,10 +14,11 @@ limitations under the License. */ #include #include "common/log.h" -#include "memory/t_malloc.h" #include "operators/math/gemm.h" #if __ARM_NEON #include +#include + #endif #ifdef _OPENMP #include @@ -30,7 +31,7 @@ void Gemm::AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, int32_t ldc) { #if __ARM_NEON #if __aarch64__ -// TODO(wzzju) +// TODO() #else const int8_t *a_ptr, *b_ptr; a_ptr = a; @@ -62,7 +63,7 @@ void Gemm::AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, "pld [%[b_ptr], #128] \n\t" "vld1.s8 {d0-d3}, [%[a_ptr]]! \n\t" // load A 8 cols "vld1.s8 {d8-d11}, [%[b_ptr]]! \n\t" // load B first 4 rows - "vmovl.s8 q2, d0 \n\t" // process B first 4 + "vmovl.s8 q2, d0 \n\t" // process B first // rows "vmovl.s8 q3, d8 \n\t" "vmlal.s16 q8, d6, d4[0]\n\t" @@ -241,12 +242,141 @@ void Gemm::AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, #endif // __ARM_NEON } +// The core idea of AddDot4x2 function is borrowed from the Google's gemmlowp +// open source library. The address of gemmlowp is +// https://github.com/google/gemmlowp. 
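// Editor's sketch, not part of this patch: a plain-C model of what the NEON
// AddDot4x2 micro-kernel below computes. It assumes k has already been padded
// to a multiple of 16 (KC) and that a/b hold the layouts produced by
// PackMatrixA_4r_16 and PackMatrixB_2c_16: per 16-wide K chunk, A stores 16
// int8 values from each of rows 0..3 back to back, and B stores 16 int8
// values from column 0 followed by 16 from column 1.
static void AddDot4x2_reference(int32_t k, const int8_t *a, const int8_t *b,
                                int32_t *c, int32_t ldc) {
  for (int32_t i = 0; i < 4; ++i) {    // 4 rows of the 4x2 output tile
    for (int32_t j = 0; j < 2; ++j) {  // 2 columns of the 4x2 output tile
      int32_t acc = 0;
      for (int32_t kc = 0; kc < k / 16; ++kc) {
        for (int32_t t = 0; t < 16; ++t) {
          acc += static_cast<int32_t>(a[kc * 64 + i * 16 + t]) *
                 static_cast<int32_t>(b[kc * 32 + j * 16 + t]);
        }
      }
      // Raw int32 sums only; bias, scaling and relu happen in the write-back.
      c[i * ldc + j] = acc;
    }
  }
}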
+void Gemm::AddDot4x2(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, + int32_t ldc) { +#if __ARM_NEON +#if __aarch64__ +// TODO +#else +#define PADDLE_LABEL_LOOP "1" +#define PADDLE_LABEL_AFTER_LOOP "2" + asm volatile( + "lsl %[ldc], %[ldc], #2 \n\t" // sizeof(int32) == 4 + "vldr d0, [%[b], #0] \n\t" + "vmov.s32 q8, #0 \n\t" + "vldr d4, [%[a], #0] \n\t" + "vmov.s32 q9, q8 \n\t" + "vldr d2, [%[b], #16] \n\t" + "vmov.s32 q10, q8 \n\t" + "vldr d6, [%[a], #16] \n\t" + "vmov.s32 q11, q8 \n\t" + "vldr d1, [%[b], #8]\n\t" + "vmov.s32 q12, q8 \n\t" + "vldr d5, [%[a], #8]\n" + "vmov.s32 q13, q8 \n\t" + "vldr d3, [%[b], #24]\n\t" + "vmov.s32 q14, q8 \n\t" + "vldr d7, [%[a], #24]\n" + "vmov.s32 q15, q8 \n\t" + + PADDLE_LABEL_LOOP + ": \n\t" + "vmull.s8 q4, d0, d4 \n\t" // first half + "add %[b], %[b], #32 \n\t" + "vmull.s8 q5, d2, d4 \n\t" + "vldr d4, [%[a], #32] \n\t" + "vmull.s8 q6, d0, d6 \n\t" + "vmull.s8 q7, d2, d6 \n\t" + "vldr d6, [%[a], #48] \n\t" + + "vmlal.s8 q4, d1, d5 \n\t" // second half + "vmlal.s8 q5, d3, d5 \n\t" + "vldr d5, [%[a], #40] \n\t" + "vmlal.s8 q6, d1, d7 \n\t" + "vmlal.s8 q7, d3, d7 \n\t" + "vldr d7, [%[a], #56] \n\t" + + "vpadal.s16 q8, q4 \n\t" // pairwise-add + "add %[a], %[a], #64 \n\t" + "vpadal.s16 q9, q5 \n\t" + "subs %[k], %[k], #16 \n\t" + "vpadal.s16 q10, q6 \n\t" + "vpadal.s16 q11, q7 \n\t" + + "beq " PADDLE_LABEL_AFTER_LOOP + "f \n\t" + + "vmull.s8 q4, d0, d4 \n\t" // first half + "vmull.s8 q5, d2, d4 \n\t" + "vldr d4, [%[a], #0] \n\t" + "vmull.s8 q6, d0, d6 \n\t" + "vldr d0, [%[b], #0] \n\t" + "vmull.s8 q7, d2, d6 \n\t" + "vldr d2, [%[b], #16] \n\t" + + "vmlal.s8 q4, d1, d5 \n\t" // second half + "vldr d6, [%[a], #16] \n\t" + "vmlal.s8 q5, d3, d5 \n\t" + "vldr d5, [%[a], #8] \n\t" + "vmlal.s8 q6, d1, d7 \n\t" + "vldr d1, [%[b], #8] \n\t" + "vmlal.s8 q7, d3, d7 \n\t" + "vldr d3, [%[b], #24] \n\t" + + "vpadal.s16 q12, q4 \n\t" // pairwise-add + "vldr d7, [%[a], #24] \n\t" + "vpadal.s16 q13, q5 \n\t" + "vpadal.s16 q14, q6 \n\t" + "vpadal.s16 q15, q7 \n\t" + + "b " PADDLE_LABEL_LOOP "b \n\t" + + PADDLE_LABEL_AFTER_LOOP + ": \n\t" + "vmull.s8 q4, d0, d4 \n\t" // first half + "vmull.s8 q5, d2, d4 \n\t" + "vmull.s8 q6, d0, d6 \n\t" + "vmull.s8 q7, d2, d6 \n\t" + + "vmlal.s8 q4, d1, d5 \n\t" // second half + "vmlal.s8 q5, d3, d5 \n\t" + "vmlal.s8 q6, d1, d7 \n\t" + "vmlal.s8 q7, d3, d7 \n\t" + + "vpadal.s16 q12, q4 \n\t" // pairwise-add + "vpadal.s16 q13, q5 \n\t" + "vpadal.s16 q14, q6 \n\t" + "vpadal.s16 q15, q7 \n\t" + + "vpadd.s32 d0, d16, d17 \n\t" // reduce to int32 + "vpadd.s32 d1, d18, d19 \n\t" + "vpadd.s32 d2, d20, d21 \n\t" + "vpadd.s32 d3, d22, d23 \n\t" + "vpadd.s32 d4, d24, d25 \n\t" + "vpadd.s32 d5, d26, d27 \n\t" + "vpadd.s32 d6, d28, d29 \n\t" + "vpadd.s32 d7, d30, d31 \n\t" + + "vpadd.s32 d8, d0, d1 \n\t" // reduce to int32 again + "vpadd.s32 d9, d2, d3 \n\t" + "vpadd.s32 d10, d4, d5 \n\t" + "vpadd.s32 d11, d6, d7 \n\t" + + "vst1.32 {d8}, [%[c]], %[ldc] \n\t" + "vst1.32 {d9}, [%[c]], %[ldc] \n\t" + "vst1.32 {d10}, [%[c]], %[ldc] \n\t" + "vst1.32 {d11}, [%[c]] \n\t" + + : [k] "+r"(k), [a] "+r"(a), [b] "+r"(b), [c] "+r"(c) + : [ldc] "r"(ldc) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", + "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +#undef PADDLE_LABEL_AFTER_LOOP +#undef PADDLE_LABEL_LOOP + +#endif // __aarch64__ +#endif // __ARM_NEON +} + // 8 bits int small block inner product void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, int32_t ldc) { #if __ARM_NEON #if __aarch64__ -// 
TODO(wzzju) +// TODO #else const int8_t *a_ptr, *b_ptr; a_ptr = a; @@ -539,51 +669,225 @@ void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, } // 8 bits int inner product -void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, int8_t alpha, - const int8_t *a, const int8_t *b, int8_t beta, - int32_t *c, int32_t *C, int32_t ldc, bool relu, - int8_t *bias) { +template <> +void Gemm::InnerKernel(int32_t mc, int32_t nc, float alpha, const int8_t *a, + const int8_t *b, float beta, int32_t *c, int8_t *C, + int32_t ldc, bool relu) {} +template <> +void Gemm::InnerKernel(int32_t mc, int32_t nc, float alpha, const int8_t *a, + const int8_t *b, float beta, int32_t *c, int32_t *C, + int32_t ldc, bool relu) { #pragma omp parallel for - for (int32_t j = 0; j < nc; j += NR) { + for (int32_t j = 0; j < nc; j += NR_INT8) { for (int32_t i = 0; i < mc; i += MR_INT8) { #if __aarch64__ - // TODO(wzzju) + // TODO #else // AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); - AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); + // AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); + AddDot4x2(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); #endif // __aarch64__ } } - if (alpha != 1) { - WriteWithAlphaBeta(mc, nc, c, C, ldc); - return; - } - if (beta == 0) { + if (!relu) { WriteBasic(mc, nc, c, C, ldc); return; } - if (beta == 1 && !relu) { - if (bias == nullptr) { - WriteWithAdd(mc, nc, c, C, ldc); - } else { - WriteWithAddV1(mc, nc, c, C, ldc, bias); +} + +template <> +void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, float alpha, + const int8_t *a, const int8_t *b, float beta, + int32_t *c, int8_t *C, int32_t ldc, bool relu, + int32_t *bias) { +#pragma omp parallel for + for (int32_t j = 0; j < nc; j += NR_INT8) { + for (int32_t i = 0; i < mc; i += MR_INT8) { +#if __aarch64__ + // TODO +#else + // AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); + // AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); + AddDot4x2(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); +#endif // __aarch64__ } + } + if (relu) { + WriteWithAddReluScale(mc, nc, c, C, ldc, bias, alpha); return; + } else { + WriteWithAddScale(mc, nc, c, C, ldc, bias, alpha); } - if (beta == 1 && relu) { - if (bias == nullptr) { - WriteWithAddRelu(mc, nc, c, C, ldc); - } else { - WriteWithAddReluV1(mc, nc, c, C, ldc, bias); +} + +template <> +void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, float alpha, + const int8_t *a, const int8_t *b, float beta, + int32_t *c, int32_t *C, int32_t ldc, bool relu, + int32_t *bias) {} + +// 8 bits int PackMatrixA_4r +void Gemm::PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail, + const int8_t *A, int32_t lda, int8_t *buffer) { + const int32_t i_length = m - m_tail; + const int32_t k_count = k >> 4; + const int32_t k_tail = k & 15; + + for (int32_t i = 0; i < i_length; i += 4) { + const int8_t *a0 = A + i * lda; + const int8_t *a1 = A + (i + 1) * lda; + const int8_t *a2 = A + (i + 2) * lda; + const int8_t *a3 = A + (i + 3) * lda; + int8_t *local_buffer = buffer + i * KC; + for (int32_t j = 0; j < k_count; ++j) { +#if __ARM_NEON +#if __aarch64__ + // TODO +#else + asm volatile( + "vld1.s8 {d0, d1}, [%[a0]]! \n\t" + "vld1.s8 {d2, d3}, [%[a1]]! \n\t" + "vld1.s8 {d4, d5}, [%[a2]]! \n\t" + "vld1.s8 {d6, d7}, [%[a3]]! \n\t" + "vst1.s8 {d0, d1}, [%[local_buffer]]! \n\t" + "vst1.s8 {d2, d3}, [%[local_buffer]]! \n\t" + "vst1.s8 {d4, d5}, [%[local_buffer]]! \n\t" + "vst1.s8 {d6, d7}, [%[local_buffer]]! 
\n\t" + : [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1), + [a2] "+r"(a2), [a3] "+r"(a3) + : + : "memory", "q0", "q1", "q2", "q3"); +#endif // __aarch64__ +#else + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a0++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a1++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a2++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a3++; + } +#endif // __ARM_NEON + } + if (k_tail != 0) { + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a0++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a1++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a2++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a3++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + } + } + + if (m_tail != 0) { + const int8_t *a0 = &A(i_length, 0); + const int8_t *a1 = a0 + lda; + const int8_t *a2 = a0 + 2 * lda; + const int8_t *a3 = a0 + 3 * lda; + int8_t *local_buffer = buffer + i_length * KC; + switch (m_tail) { + case 1: + a1 = zero_int8; + case 2: + a2 = zero_int8; + case 3: + a3 = zero_int8; + break; + default: + break; + } + for (int32_t j = 0; j < k_count; ++j) { +#if __ARM_NEON +#if __aarch64__ + // TODO +#else + asm volatile( + "vld1.s8 {d0, d1}, [%[a0]]! \n\t" + "vld1.s8 {d2, d3}, [%[a1]]! \n\t" + "vld1.s8 {d4, d5}, [%[a2]]! \n\t" + "vld1.s8 {d6, d7}, [%[a3]]! \n\t" + "vst1.s8 {d0, d1}, [%[local_buffer]]! \n\t" + "vst1.s8 {d2, d3}, [%[local_buffer]]! \n\t" + "vst1.s8 {d4, d5}, [%[local_buffer]]! \n\t" + "vst1.s8 {d6, d7}, [%[local_buffer]]! 
\n\t" + : [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1), + [a2] "+r"(a2), [a3] "+r"(a3) + : + : "memory", "q0", "q1", "q2", "q3"); +#endif // __aarch64__ +#else + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a0++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a1++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a2++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a3++; + } +#endif // __ARM_NEON + } + if (k_tail != 0) { + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a0++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a1++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a2++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a3++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } } - return; } } + // 8 bits int PackMatrixA_4r void Gemm::PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, int32_t lda, int8_t *buffer) { const int8_t *a0, *a1, *a2, *a3; - for (int32_t i = 0; i < m - m_tail; i += MR_INT8) { + for (int32_t i = 0; i < m - m_tail; i += 4) { a0 = A + i * lda; a1 = A + (i + 1) * lda; a2 = A + (i + 2) * lda; @@ -625,7 +929,7 @@ void Gemm::PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, void Gemm::PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, int32_t lda, int8_t *buffer) { const int32_t i_length = m - m_tail; - for (int32_t i = 0; i < i_length; i += MR_INT8) { + for (int32_t i = 0; i < i_length; i += 6) { const int8_t *a0 = A + i * lda; const int8_t *a1 = A + (i + 1) * lda; const int8_t *a2 = A + (i + 2) * lda; @@ -676,17 +980,85 @@ void Gemm::PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, } } +// 8 bits int PackMatrixB +void Gemm::PackMatrixB_2c_16(int32_t k, int32_t n, int32_t n_tail, + const int8_t *B, int32_t ldb, int8_t *buffer) { + const int32_t j_length = n - n_tail; + const int32_t k_count = k >> 4; + const int32_t k_tail = k & 15; + for (int32_t j = 0; j < j_length; j += 2) { + int8_t *local_buffer = buffer + j * KC; + for (int32_t i = 0; i < k_count; ++i) { + const int8_t *b0 = &B((i << 4), j); + const int8_t *b1 = &B((i << 4), j + 1); + for (int m = 0; m < 16; ++m) { + *local_buffer++ = *b0; + b0 += ldb; + } + for (int m = 0; m < 16; ++m) { + *local_buffer++ = *b1; + b1 += ldb; + } + } + if (k_tail != 0) { + const int8_t *b0 = &B((k_count << 4), j); + const int8_t *b1 = &B((k_count << 4), j + 1); + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *b0; + b0 += ldb; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *b1; + b1 += ldb; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + } + } + if (n_tail != 0) { + int8_t *local_buffer = buffer + j_length * KC; + for (int32_t i = 0; i < k_count; ++i) { + const int8_t *b0 = &B((i << 4), j_length); + for (int m = 0; m < 16; ++m) { + *local_buffer++ = *b0; + b0 += ldb; + } + for (int m = 0; m < 16; ++m) { + *local_buffer++ = 0; + } + } + if (k_tail != 0) { + const int8_t *b0 = &B((k_count << 4), j_length); + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *b0; + b0 += ldb; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + 
} + for (int32_t j = k_count << 4; j < KC; ++j) { + *local_buffer++ = 0; + } + } + } +} + // 8 bits int PackMatrixB void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, int32_t ldb, int8_t *buffer) { const int32_t j_length = n - n_tail; - for (int32_t j = 0; j < j_length; j += NR) { + for (int32_t j = 0; j < j_length; j += 8) { int8_t *local_buffer = buffer + j * k; for (int32_t i = 0; i < k; ++i) { const int8_t *b0 = &B(i, j); #if __ARM_NEON #if __aarch64__ - // TODO(wzzju) + // TODO #else asm volatile( // "pld [%[b0]] \n\t" @@ -715,94 +1087,27 @@ void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, for (int32_t j = j_length; j < n; ++j) { *local_buffer++ = *b0++; } - for (int32_t j = n; j < j_length + NR; ++j) { + for (int32_t j = n; j < j_length + 8; ++j) { *local_buffer++ = 0; } } } } -// 8 bits int matrix product (m*k x k*n) -void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A, - int32_t lda, const int8_t *B, int32_t ldb, int8_t beta, - int32_t *C, int32_t ldc, bool relu, int8_t *bias) { - // L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73) - // L2 cache is 0.5~4 Mib (Contex-A72 cluster) - int32_t L1 = 32 * 1024; - int32_t L2 = 512 * 1024; - - KC = k; - MC = L1 / (KC * sizeof(int8_t)); - NC = L2 / (KC * sizeof(int8_t)); - - // make sure MC is multiple of MR_INT8, and NC is multiple of NR - if (MC == 0) { - MC = MR_INT8; - } else { - int32_t mblock_num = (m + MC - 1) / MC; - MC = (m + mblock_num - 1) / mblock_num; - MC = (MC + MR_INT8 - 1) / MR_INT8 * MR_INT8; - } - // DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n"; - if (NC == 0) { - NC = NR; - } else { - int32_t nblock_num = (n + NC - 1) / NC; - NC = (n + nblock_num - 1) / nblock_num; - NC = (NC + NR - 1) / NR * NR; - } - // DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n"; - packedA_int8 = static_cast( - paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC)); - packedB_int8 = static_cast( - paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC)); - packedC_int8 = static_cast( - paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC)); - zero_int8 = - static_cast(paddle_mobile::memory::Alloc(sizeof(int8_t) * KC)); - - memset(static_cast(zero_int8), 0, sizeof(int8_t) * KC); - int32_t mc, nc; - for (int32_t j = 0; j < n; j += NC) { - nc = s_min(n - j, NC); - PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB_int8); - for (int32_t i = 0; i < m; i += MC) { - mc = s_min(m - i, MC); - // PackMatrixA_6r(mc, KC, mc % MR_INT8, &A(i, 0), lda, packedA_int8); - PackMatrixA_4r(mc, KC, mc % MR_INT8, &A(i, 0), lda, packedA_int8); - if (bias == nullptr) { - InnerKernelWithBias(mc, nc, alpha, packedA_int8, packedB_int8, beta, - packedC_int8, &C(i, j), ldc, relu, nullptr); - } else { - InnerKernelWithBias(mc, nc, alpha, packedA_int8, packedB_int8, beta, - packedC_int8, &C(i, j), ldc, relu, bias + i); - } - } - } - - paddle_mobile::memory::Free(packedA_int8); - paddle_mobile::memory::Free(packedB_int8); - paddle_mobile::memory::Free(packedC_int8); - paddle_mobile::memory::Free(zero_int8); -} - // 8 bits int write back -// C = alpha * A * B + beta * C -void Gemm::WriteWithAlphaBeta(int32_t mc, int32_t nc, int32_t *c, int32_t *C, - int32_t ldc) {} -// C = A * B, 8位 int32_t +// C = A * B void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, int32_t ldc) { #if __ARM_NEON #if __aarch64__ -// TODO(wzzju) +// TODO #else int32_t nc1 = nc >> 4; int32_t _nc1 = nc & 15; int32_t step = sizeof(int32_t) * ldc; 
int32_t step1 = sizeof(int32_t) * (NC - (nc1 << 4)); int32_t volatile m = mc; - + int32_t volatile n = nc1; int32_t *volatile c_ptr, *volatile C_ptr; int32_t *C0, *c0; c_ptr = c; @@ -836,7 +1141,7 @@ void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, "end_mc_%=: \n\t" : - : [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(nc1), + : [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(n), [step] "r"(step), [step1] "r"(step1) : "memory", "r5", "r6", "q0", "q1", "q2", "q3"); } @@ -854,20 +1159,254 @@ void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, #endif // __ARM_NEON } -// C = A * B + C -void Gemm::WriteWithAdd(int32_t mc, int32_t nc, int32_t *c, int32_t *C, - int32_t ldc) {} +// C = A * B + bias, scale * C +void Gemm::WriteWithAddScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C, + int32_t ldc, int32_t *bias, float scale) { +#if __ARM_NEON +#if __aarch64__ +// TODO +#else + int32_t zero = 0; + int8_t narrow = -128; + int32_t nc1 = nc >> 3; + int32_t _nc1 = nc & 7; + int32_t step = sizeof(int8_t) * ldc; + int32_t step1 = sizeof(int32_t) * (NC - (nc1 << 3)); + int32_t volatile m = mc; + int32_t volatile n = nc1; + int32_t *volatile c_ptr, *volatile bias_ptr; + int8_t *volatile C_ptr; + c_ptr = c; + C_ptr = C; + bias_ptr = bias; + if (nc1 > 0) { + asm volatile( + "subs %[mc], %[mc], #1 \n\t" + "blt end_mc_%= \n\t" + "vdup.32 q15, %[scale] \n\t" + "vdup.32 q14, %[zero] \n\t" + "vdup.8 d24, %[narrow] \n\t" + "loop_mc_%=: \n\t" + "vld1.32 {d26[0]}, [%[bias_ptr]]!\n\t" + "vdup.32 q13, d26[0] \n\t" + "mov r6, %[C_ptr] \n\t" + "mov r5, %[nc1] \n\t" + "subs r5, r5, #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" + "vld1.32 {q0, q1}, [%[c_ptr]]! \n\t" + "vqadd.s32 q0, q0, q13 \n\t" + "vqadd.s32 q1, q1, q13 \n\t" + "vcvt.f32.s32 q2, q0 \n\t" + "vcvt.f32.s32 q3, q1 \n\t" + "vmul.f32 q2, q2, q15 \n\t" + "vmul.f32 q3, q3, q15 \n\t" + "vcvt.s32.f32 q4, q2 \n\t" + "vcvt.s32.f32 q5, q3 \n\t" + "vqmovn.s32 d12, q4 \n\t" + "vqmovn.s32 d13, q5 \n\t" + "vqmovn.s16 d14, q6 \n\t" + "vceq.s8 d15, d14, d24 \n\t" + "vsub.s8 d14, d14, d15 \n\t" + "vst1.8 {d14}, [r6]! \n\t" + "subs r5, r5, #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" + + "add %[C_ptr], %[C_ptr], %[step] \n\t" + "add %[c_ptr], %[c_ptr], %[step1] \n\t" + "subs %[mc], %[mc], #1 \n\t" + "bge loop_mc_%= \n\t" + "end_mc_%=: \n\t" + + : + : [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(n), + [step] "r"(step), [step1] "r"(step1), [bias_ptr] "r"(bias_ptr), + [scale] "r"(scale), [zero] "r"(zero), [narrow] "r"(narrow) + : "cc", "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", + "q7", "q12", "q13", "q14", "q15"); + } + + int32_t nc_left; + int32_t *c0; + int8_t *C0; + int32_t bias_v; + if (_nc1 != 0) { + for (int32_t i = 0; i < mc; i++) { + C0 = C_ptr + nc1 * 8 + i * ldc; + c0 = c_ptr + nc1 * 8 + i * NC; + bias_v = *(bias_ptr + i); + nc_left = _nc1; + asm volatile( + "vdup.32 q15, %[scale] \n\t" + "vdup.32 q14, %[zero] \n\t" + "vdup.8 d24, %[narrow] \n\t" + "vdup.32 q13, %[bias_v] \n\t" + "cmp %[_nc1], #4 \n\t" + "blt less_four_%= \n\t" + "vld1.32 {q0}, [%[c0]]! \n\t" + "vqadd.s32 q0, q0, q13 \n\t" + "vcvt.f32.s32 q1, q0 \n\t" + "vmul.f32 q1, q1, q15 \n\t" + "vcvt.s32.f32 q2, q1 \n\t" + "vqmovn.s32 d6, q2 \n\t" + "vqmovn.s16 d8, q3 \n\t" + "vceq.s8 d9, d8, d24 \n\t" + "vsub.s8 d8, d8, d9 \n\t" + "vst1.8 {d8[0]}, [%[C0]]! \n\t" + "vst1.8 {d8[1]}, [%[C0]]! \n\t" + "vst1.8 {d8[2]}, [%[C0]]! \n\t" + "vst1.8 {d8[3]}, [%[C0]]! 
\n\t" + "subs %[_nc1], %[_nc1], #4 \n\t" + "beq process_over_%= \n\t" + "less_four_%=: \n\t" + "vld1.32 {q0}, [%[c0]]! \n\t" + "vqadd.s32 q0, q0, q13 \n\t" + "vcvt.f32.s32 q1, q0 \n\t" + "vmul.f32 q1, q1, q15 \n\t" + "vcvt.s32.f32 q2, q1 \n\t" + "vqmovn.s32 d6, q2 \n\t" + "vqmovn.s16 d8, q3 \n\t" + "vceq.s8 d9, d8, d24 \n\t" + "vsub.s8 d8, d8, d9 \n\t" + "loop_save_%=: \n\t" + "vst1.8 {d8[0]}, [%[C0]]! \n\t" + "vext.8 d8, d8, d8, #1 \n\t" + "subs %[_nc1], %[_nc1], #1 \n\t" + "bgt loop_save_%= \n\t" + "process_over_%=: \n\t" + : + : [_nc1] "r"(nc_left), [C0] "r"(C0), [c0] "r"(c0), + [bias_v] "r"(bias_v), [scale] "r"(scale), [zero] "r"(zero), + [narrow] "r"(narrow) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q12", "q13", "q14", + "q15"); + } + } +#endif // __aarch64__ +#endif // __ARM_NEON +} + +// C = A * B + bias, scale * relu(C) +void Gemm::WriteWithAddReluScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C, + int32_t ldc, int32_t *bias, float scale) { +#if __ARM_NEON +#if __aarch64__ +// TODO +#else + int32_t zero = 0; + int32_t nc1 = nc >> 3; + int32_t _nc1 = nc & 7; + int32_t step = sizeof(int8_t) * ldc; + int32_t step1 = sizeof(int32_t) * (NC - (nc1 << 3)); + int32_t volatile m = mc; + int32_t volatile n = nc1; + int32_t *volatile c_ptr, *volatile bias_ptr; + int8_t *volatile C_ptr; + c_ptr = c; + C_ptr = C; + bias_ptr = bias; + if (nc1 > 0) { + asm volatile( + "subs %[mc], %[mc], #1 \n\t" + "blt end_mc_%= \n\t" + "vdup.32 q15, %[scale] \n\t" + "vdup.32 q14, %[zero] \n\t" + "loop_mc_%=: \n\t" + "vld1.32 {d26[0]}, [%[bias_ptr]]!\n\t" + "vdup.32 q13, d26[0] \n\t" + "mov r6, %[C_ptr] \n\t" + "mov r5, %[nc1] \n\t" + "subs r5, r5, #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" + "vld1.32 {q0, q1}, [%[c_ptr]]! \n\t" + "vqadd.s32 q0, q0, q13 \n\t" + "vqadd.s32 q1, q1, q13 \n\t" + "vmax.s32 q0, q0, q14 \n\t" + "vmax.s32 q1, q1, q14 \n\t" + "vcvt.f32.s32 q2, q0 \n\t" + "vcvt.f32.s32 q3, q1 \n\t" + "vmul.f32 q2, q2, q15 \n\t" + "vmul.f32 q3, q3, q15 \n\t" + "vcvt.s32.f32 q4, q2 \n\t" + "vcvt.s32.f32 q5, q3 \n\t" + "vqmovn.s32 d12, q4 \n\t" + "vqmovn.s32 d13, q5 \n\t" + "vqmovn.s16 d14, q6 \n\t" + "vst1.8 {d14}, [r6]! \n\t" + "subs r5, r5, #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" + + "add %[C_ptr], %[C_ptr], %[step] \n\t" + "add %[c_ptr], %[c_ptr], %[step1] \n\t" + "subs %[mc], %[mc], #1 \n\t" + "bge loop_mc_%= \n\t" + "end_mc_%=: \n\t" -// C = A * B + bias -void Gemm::WriteWithAddV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C, - int32_t ldc, int8_t *bias) {} -// C = A * B + C, relu(C) -void Gemm::WriteWithAddRelu(int32_t mc, int32_t nc, int32_t *c, int32_t *C, - int32_t ldc) {} + : + : [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(n), + [step] "r"(step), [step1] "r"(step1), [bias_ptr] "r"(bias_ptr), + [scale] "r"(scale), [zero] "r"(zero) + : "cc", "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", + "q7", "q13", "q14", "q15"); + } -// C = A * B + bias, relu(C) -void Gemm::WriteWithAddReluV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C, - int32_t ldc, int8_t *bias) {} + int32_t nc_left; + int32_t *c0; + int8_t *C0; + int32_t bias_v; + if (_nc1 != 0) { + for (int32_t i = 0; i < mc; i++) { + C0 = C_ptr + nc1 * 8 + i * ldc; + c0 = c_ptr + nc1 * 8 + i * NC; + bias_v = *(bias_ptr + i); + nc_left = _nc1; + asm volatile( + "vdup.32 q15, %[scale] \n\t" + "vdup.32 q14, %[zero] \n\t" + "vdup.32 q13, %[bias_v] \n\t" + "cmp %[_nc1], #4 \n\t" + "blt less_four_%= \n\t" + "vld1.32 {q0}, [%[c0]]! 
\n\t" + "vqadd.s32 q0, q0, q13 \n\t" + "vmax.s32 q0, q0, q14 \n\t" + "vcvt.f32.s32 q1, q0 \n\t" + "vmul.f32 q1, q1, q15 \n\t" + "vcvt.s32.f32 q2, q1 \n\t" + "vqmovn.s32 d6, q2 \n\t" + "vqmovn.s16 d8, q3 \n\t" + "vst1.8 {d8[0]}, [%[C0]]! \n\t" + "vst1.8 {d8[1]}, [%[C0]]! \n\t" + "vst1.8 {d8[2]}, [%[C0]]! \n\t" + "vst1.8 {d8[3]}, [%[C0]]! \n\t" + "subs %[_nc1], %[_nc1], #4 \n\t" + "beq process_over_%= \n\t" + "less_four_%=: \n\t" + "vld1.32 {q0}, [%[c0]]! \n\t" + "vqadd.s32 q0, q0, q13 \n\t" + "vmax.s32 q0, q0, q14 \n\t" + "vcvt.f32.s32 q1, q0 \n\t" + "vmul.f32 q1, q1, q15 \n\t" + "vcvt.s32.f32 q2, q1 \n\t" + "vqmovn.s32 d6, q2 \n\t" + "vqmovn.s16 d8, q3 \n\t" + "loop_save_%=: \n\t" + "vst1.8 {d8[0]}, [%[C0]]! \n\t" + "vext.8 d8, d8, d8, #1 \n\t" + "subs %[_nc1], %[_nc1], #1 \n\t" + "bgt loop_save_%= \n\t" + "process_over_%=: \n\t" + : + : [_nc1] "r"(nc_left), [C0] "r"(C0), [c0] "r"(c0), + [bias_v] "r"(bias_v), [scale] "r"(scale), [zero] "r"(zero) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q13", "q14", "q15"); + } + } +#endif // __aarch64__ +#endif // __ARM_NEON +} } // namespace math } // namespace operators diff --git a/src/operators/math/gemm_omp_int8.cpp b/src/operators/math/gemm_omp_int8.cpp index 21256cccfcc6dcc647f34a2129616b70804d398f..61f0be418f27e6264c310fcc58b9a652d5f9805e 100644 --- a/src/operators/math/gemm_omp_int8.cpp +++ b/src/operators/math/gemm_omp_int8.cpp @@ -27,130 +27,17 @@ namespace paddle_mobile { namespace operators { namespace math { -// 8 bits int matrix product (m*k x k*n) -void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha, - const int8_t *A, int32_t lda, const int8_t *B, int32_t ldb, - int8_t beta, int32_t *C, int32_t ldc, bool relu, - int8_t *bias) { -#ifdef _OPENMP - int32_t max_threads = omp_get_max_threads(); -#else - int32_t max_threads = 1; -#endif - - int32_t L1 = 64 / max_threads * 1024; - KC = k; - zero_int8 = - static_cast(paddle_mobile::memory::Alloc(sizeof(int8_t) * KC)); - memset(static_cast(zero_int8), 0, sizeof(int8_t) * KC); - if (m > n) { - // 对 A 分块 - MC = L1 / (KC * sizeof(int8_t)); - if (MC == 0) { - MC = MR_INT8; - } else { - int32_t mblock_num = (m + MC - 1) / MC; - MC = (m + mblock_num - 1) / mblock_num; - MC = (MC + MR_INT8 - 1) / MR_INT8 * MR_INT8; - } - // 补齐 B - NC = (n + NR - 1) / NR * NR; - - packedB_int8 = static_cast( - paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC)); -#if __aarch64__ - // TODO(wzzju) -#else - PackMatrixB_omp_8c(KC, n, n % NR, B, ldb, packedB_int8); -#endif - packedA_int8 = static_cast( - paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC * max_threads)); - } else { - // 对 B 分块 - NC = L1 / (KC * sizeof(int8_t)); - if (NC == 0) { - NC = NR; - } else { - int32_t nblock_num = (n + NC - 1) / NC; - NC = (n + nblock_num - 1) / nblock_num; - NC = (NC + NR - 1) / NR * NR; - } - // 补齐 A - MC = (m + MR_INT8 - 1) / MR_INT8 * MR_INT8; - - packedA_int8 = static_cast( - paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC)); -#if __aarch64__ - // TODO(wzzju) -#else - PackMatrixA_omp_4r(m, KC, m % MR_INT8, A, lda, packedA_int8); -#endif - packedB_int8 = static_cast( - paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC * max_threads)); - } - packedC_int8 = static_cast( - paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC * max_threads)); - - if (m > n) { -#pragma omp parallel for - for (int32_t i = 0; i < m; i += MC) { -#ifdef _OPENMP - int32_t local_threads = omp_get_thread_num(); -#else - int32_t local_threads = 0; -#endif - - int32_t mc; - mc = s_min(m - i, MC); - int8_t *local_A = 
packedA_int8 + MC * KC * local_threads; - int32_t *local_C = packedC_int8 + MC * NC * local_threads; -#if __aarch64__ - // TODO(wzzju) -#else - PackMatrixA_4r(mc, KC, mc % MR_INT8, &A(i, 0), lda, local_A); -#endif - InnerKernelWithBias(mc, n, alpha, local_A, packedB_int8, beta, local_C, - &C(i, 0), ldc, relu, bias + i); - } - } else { -#pragma omp parallel for - for (int32_t j = 0; j < n; j += NC) { -#ifdef _OPENMP - int32_t local_threads = omp_get_thread_num(); -#else - int32_t local_threads = 0; -#endif - int32_t nc; - nc = s_min(n - j, NC); - int8_t *local_B = packedB_int8 + KC * NC * local_threads; - int32_t *local_C = packedC_int8 + MC * NC * local_threads; -#if __aarch64__ - // TODO(wzzju) -#else - PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, local_B); -#endif - InnerKernelWithBias(m, nc, alpha, packedA_int8, local_B, beta, local_C, - &C(0, j), ldc, relu, bias); - } - } - - paddle_mobile::memory::Free(packedA_int8); - paddle_mobile::memory::Free(packedB_int8); - paddle_mobile::memory::Free(packedC_int8); - paddle_mobile::memory::Free(zero_int8); -} - void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, int32_t ldb, int8_t *buffer) { const int32_t j_length = n - n_tail; #pragma omp parallel for - for (int32_t j = 0; j < j_length; j += NR) { + for (int32_t j = 0; j < j_length; j += 8) { int8_t *local_buffer = buffer + j * k; for (int32_t i = 0; i < k; ++i) { const int8_t *b0 = &B(i, j); #if __ARM_NEON #if __aarch64__ - // TODO(wzzju) + // TODO #else asm volatile( // "pld [%[b0]] \n\t" @@ -179,7 +66,7 @@ void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail, for (int32_t j = j_length; j < n; ++j) { *local_buffer++ = *b0++; } - for (int32_t j = n; j < j_length + NR; ++j) { + for (int32_t j = n; j < j_length + 8; ++j) { *local_buffer++ = 0; } } @@ -188,9 +75,9 @@ void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail, void Gemm::PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, int32_t lda, int8_t *buffer) { - const int i_length = m - m_tail; + const int32_t i_length = m - m_tail; #pragma omp parallel for - for (int32_t i = 0; i < i_length; i += MR_INT8) { + for (int32_t i = 0; i < i_length; i += 4) { const int8_t *a0 = A + i * lda; const int8_t *a1 = A + (i + 1) * lda; const int8_t *a2 = A + (i + 2) * lda; @@ -221,7 +108,7 @@ void Gemm::PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail, default: break; } - for (int j = 0; j < k; ++j) { + for (int32_t j = 0; j < k; ++j) { *local_buffer++ = *a0++; *local_buffer++ = *a1++; *local_buffer++ = *a2++; @@ -230,6 +117,232 @@ void Gemm::PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail, } } +// 8 bits int PackMatrixA_4r +void Gemm::PackMatrixA_omp_4r_16(int32_t m, int32_t k, int32_t m_tail, + const int8_t *A, int32_t lda, int8_t *buffer) { + const int32_t i_length = m - m_tail; + const int32_t k_count = k >> 4; + const int32_t k_tail = k & 15; +#pragma omp parallel for + for (int32_t i = 0; i < i_length; i += 4) { + const int8_t *a0 = A + i * lda; + const int8_t *a1 = A + (i + 1) * lda; + const int8_t *a2 = A + (i + 2) * lda; + const int8_t *a3 = A + (i + 3) * lda; + int8_t *local_buffer = buffer + i * KC; + for (int32_t j = 0; j < k_count; ++j) { +#if __ARM_NEON +#if __aarch64__ + // TODO +#else + asm volatile( + "vld1.s8 {d0, d1}, [%[a0]]! \n\t" + "vld1.s8 {d2, d3}, [%[a1]]! \n\t" + "vld1.s8 {d4, d5}, [%[a2]]! \n\t" + "vld1.s8 {d6, d7}, [%[a3]]! \n\t" + "vst1.s8 {d0, d1}, [%[local_buffer]]! \n\t" + "vst1.s8 {d2, d3}, [%[local_buffer]]! 
\n\t" + "vst1.s8 {d4, d5}, [%[local_buffer]]! \n\t" + "vst1.s8 {d6, d7}, [%[local_buffer]]! \n\t" + : [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1), + [a2] "+r"(a2), [a3] "+r"(a3) + : + : "memory", "q0", "q1", "q2", "q3"); +#endif // __aarch64__ +#else + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a0++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a1++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a2++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a3++; + } +#endif // __ARM_NEON + } + if (k_tail != 0) { + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a0++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a1++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a2++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a3++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + } + } + + if (m_tail != 0) { + const int8_t *a0 = &A(i_length, 0); + const int8_t *a1 = a0 + lda; + const int8_t *a2 = a0 + 2 * lda; + const int8_t *a3 = a0 + 3 * lda; + int8_t *local_buffer = buffer + i_length * KC; + switch (m_tail) { + case 1: + a1 = zero_int8; + case 2: + a2 = zero_int8; + case 3: + a3 = zero_int8; + break; + default: + break; + } + for (int32_t j = 0; j < k_count; ++j) { +#if __ARM_NEON +#if __aarch64__ + // TODO +#else + asm volatile( + "vld1.s8 {d0, d1}, [%[a0]]! \n\t" + "vld1.s8 {d2, d3}, [%[a1]]! \n\t" + "vld1.s8 {d4, d5}, [%[a2]]! \n\t" + "vld1.s8 {d6, d7}, [%[a3]]! \n\t" + "vst1.s8 {d0, d1}, [%[local_buffer]]! \n\t" + "vst1.s8 {d2, d3}, [%[local_buffer]]! \n\t" + "vst1.s8 {d4, d5}, [%[local_buffer]]! \n\t" + "vst1.s8 {d6, d7}, [%[local_buffer]]! 
\n\t" + : [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1), + [a2] "+r"(a2), [a3] "+r"(a3) + : + : "memory", "q0", "q1", "q2", "q3"); +#endif // __aarch64__ +#else + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a0++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a1++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a2++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a3++; + } +#endif // __ARM_NEON + } + if (k_tail != 0) { + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a0++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a1++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a2++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a3++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + } + } +} + +// 8 bits int PackMatrixB +void Gemm::PackMatrixB_omp_2c_16(int32_t k, int32_t n, int32_t n_tail, + const int8_t *B, int32_t ldb, int8_t *buffer) { + const int32_t j_length = n - n_tail; + const int32_t k_count = k >> 4; + const int32_t k_tail = k & 15; +#pragma omp parallel for + for (int32_t j = 0; j < j_length; j += 2) { + int8_t *local_buffer = buffer + j * KC; + for (int32_t i = 0; i < k_count; ++i) { + const int8_t *b0 = &B((i << 4), j); + const int8_t *b1 = &B((i << 4), j + 1); + for (int m = 0; m < 16; ++m) { + *local_buffer++ = *b0; + b0 += ldb; + } + for (int m = 0; m < 16; ++m) { + *local_buffer++ = *b1; + b1 += ldb; + } + } + if (k_tail != 0) { + const int8_t *b0 = &B((k_count << 4), j); + const int8_t *b1 = &B((k_count << 4), j + 1); + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *b0; + b0 += ldb; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *b1; + b1 += ldb; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + } + } + if (n_tail != 0) { + int8_t *local_buffer = buffer + j_length * KC; + for (int32_t i = 0; i < k_count; ++i) { + const int8_t *b0 = &B((i << 4), j_length); + for (int m = 0; m < 16; ++m) { + *local_buffer++ = *b0; + b0 += ldb; + } + for (int m = 0; m < 16; ++m) { + *local_buffer++ = 0; + } + } + if (k_tail != 0) { + const int8_t *b0 = &B((k_count << 4), j_length); + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *b0; + b0 += ldb; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + for (int32_t j = k_count << 4; j < KC; ++j) { + *local_buffer++ = 0; + } + } + } +} + } // namespace math } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/math/gru_compute.cpp b/src/operators/math/gru_compute.cpp index 9e77f572c53bc2ba9be57f5edbd2b4bf85f5305e..bbf1b01a21a980293f3cfe255885e7127aeb208e 100644 --- a/src/operators/math/gru_compute.cpp +++ b/src/operators/math/gru_compute.cpp @@ -34,12 +34,12 @@ struct GRUUnitFunctor { gemm.Sgemm_omp(batch_size, frame_size * 2, frame_size, 1, value.prev_out_value, frame_size, value.gate_weight, frame_size * 2, 1, value.gate_value, frame_size * 3, false, - nullptr); + static_cast(nullptr)); #else gemm.Sgemm(batch_size, frame_size * 2, frame_size, 1, value.prev_out_value, frame_size, value.gate_weight, frame_size * 2, 1, value.gate_value, frame_size * 3, false, - nullptr); + 
static_cast(nullptr)); #endif } @@ -51,12 +51,12 @@ struct GRUUnitFunctor { gemm.Sgemm_omp(batch_size, frame_size, frame_size, 1, value.reset_output_value, frame_size, value.state_weight, frame_size, 1, value.gate_value + frame_size * 2, - frame_size * 3, false, nullptr); + frame_size * 3, false, static_cast(nullptr)); #else gemm.Sgemm(batch_size, frame_size, frame_size, 1, value.reset_output_value, frame_size, value.state_weight, frame_size, 1, value.gate_value + frame_size * 2, - frame_size * 3, false, nullptr); + frame_size * 3, false, static_cast(nullptr)); #endif } diff --git a/src/operators/math/math_function.h b/src/operators/math/math_function.h index b91242c1868398e4541c3727567a905e5b0c8714..c58e8035940c65646851961bc2b9d12307f37e7a 100644 --- a/src/operators/math/math_function.h +++ b/src/operators/math/math_function.h @@ -28,7 +28,13 @@ template void matmul(const framework::Tensor &matrix_a, bool trans_a, const framework::Tensor &matrix_b, bool trans_b, T alpha, framework::Tensor *matrix_out, T beta, bool relu = false, - T *bias = nullptr); + float *bias = nullptr); + +template +void matmul(const framework::Tensor &matrix_a, bool trans_a, + const framework::Tensor &matrix_b, bool trans_b, T alpha, + framework::Tensor *matrix_out, T beta, bool relu = false, + S *bias = nullptr); template void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a, diff --git a/src/operators/math/math_function_int8.cpp b/src/operators/math/math_function_int8.cpp index e02824b290ebc0080613e2ae2365626d79576c9e..a407a2915dbe6c17537b85371b9426acfd4a1b2c 100644 --- a/src/operators/math/math_function_int8.cpp +++ b/src/operators/math/math_function_int8.cpp @@ -20,11 +20,12 @@ limitations under the License. */ namespace paddle_mobile { namespace operators { namespace math { + template <> -void matmul(const framework::Tensor &matrix_a, bool trans_a, - const framework::Tensor &matrix_b, bool trans_b, - int8_t alpha, framework::Tensor *matrix_out, int8_t beta, - bool relu, int8_t *bias) { +void matmul(const framework::Tensor &matrix_a, bool trans_a, + const framework::Tensor &matrix_b, bool trans_b, float alpha, + framework::Tensor *matrix_out, float beta, bool relu, + int32_t *bias) { auto dim_a = matrix_a.dims(); auto dim_b = matrix_b.dims(); auto dim_out = matrix_out->dims(); @@ -52,21 +53,43 @@ void matmul(const framework::Tensor &matrix_a, bool trans_a, } #ifdef _OPENMP - gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data(), N, beta, - matrix_out->data(), N, relu, bias); + if (bias != nullptr) { + gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data(), N, beta, + matrix_out->data(), N, relu, bias); + } else { + gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data(), N, beta, + matrix_out->data(), N, relu, bias); + } #else - gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data(), N, beta, - matrix_out->data(), N, relu, bias); + if (bias != nullptr) { + gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data(), N, beta, + matrix_out->data(), N, relu, bias); + } else { + gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data(), N, beta, + matrix_out->data(), N, relu, bias); + } #endif } else { #ifdef _OPENMP - gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data(), K, - matrix_b.data(), N, beta, - matrix_out->data(), N, relu, bias); + if (bias != nullptr) { + gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data(), K, + matrix_b.data(), N, beta, + matrix_out->data(), N, relu, bias); + } else { + gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data(), K, + matrix_b.data(), N, beta, + matrix_out->data(), N, relu, bias); + } #else - gemm.Sgemm(M, N, K, 
alpha, matrix_a.data(), K, - matrix_b.data(), N, beta, matrix_out->data(), N, - relu, bias); + if (bias != nullptr) { + gemm.Sgemm(M, N, K, alpha, matrix_a.data(), K, + matrix_b.data(), N, beta, matrix_out->data(), + N, relu, bias); + } else { + gemm.Sgemm(M, N, K, alpha, matrix_a.data(), K, + matrix_b.data(), N, beta, matrix_out->data(), + N, relu, bias); + } #endif } } diff --git a/src/operators/math/pool_3x3.cpp b/src/operators/math/pool_3x3.cpp index dadb5a67cf6dda531b15783feafe5cee370e109a..a2b84a5b143c3cbd2db41223636505aa1d43a7f7 100644 --- a/src/operators/math/pool_3x3.cpp +++ b/src/operators/math/pool_3x3.cpp @@ -38,6 +38,7 @@ void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) { const int input_width = static_cast(input->dims()[3]); const int output_height = static_cast(output->dims()[2]); const int output_width = static_cast(output->dims()[3]); + output->mutable_data(); const int hxw = input_height * input_width; @@ -472,7 +473,7 @@ void Pool3x3Maxs1p1(const Tensor *input, Tensor *output) { const int inputdata_channel_stride = h_in * w_in; const int input_batch_stride = output_channels * inputdata_channel_stride; const int output_batch_stride = output_channels * outputdata_channel_stride; - float *out_data = output->data(); + float *out_data = output->mutable_data(); const float *input_data = input->data(); for (int k = 0; k < batch_size; ++k) { #pragma omp parallel for diff --git a/src/operators/math/pool_3x3.h b/src/operators/math/pool_3x3.h index ac1eb16a4c0e077c625267545767b8f29144b8f1..a13cb6ab374ba22050d17e9c9fb5a1e94f857fb2 100644 --- a/src/operators/math/pool_3x3.h +++ b/src/operators/math/pool_3x3.h @@ -28,15 +28,21 @@ limitations under the License. */ namespace paddle_mobile { namespace operators { namespace math { -using framework::Tensor; -using std::vector; -void Pool3x3Avgs1p1(const Tensor *input, Tensor *output); -void Pool3x3Maxs1p1(const Tensor *input, Tensor *output); -void Pool3x3Max(vector strides, vector paddings, const Tensor *input, - Tensor *output); - -void Pool3x3Avg(vector strides, vector paddings, const Tensor *in_x, - Tensor *out); +void Pool3x3Avgs1p1(const framework::Tensor *input, framework::Tensor *output); +void Pool3x3Maxs1p1(const framework::Tensor *input, framework::Tensor *output); +void Pool3x3Max(std::vector strides, std::vector paddings, + const framework::Tensor *input, framework::Tensor *output); + +void Pool3x3Avg(std::vector strides, std::vector paddings, + const framework::Tensor *in_x, framework::Tensor *out); + +void Pool3x3Maxs1_int8(const framework::Tensor *input, + framework::Tensor *output, int32_t pad_h, int32_t pad_w); +void Pool3x3Maxs2_int8(const framework::Tensor *input, + framework::Tensor *output, int32_t pad_h, int32_t pad_w); +void Pool3x3Max_int8(const std::vector &strides, + const std::vector &paddings, + const framework::Tensor *input, framework::Tensor *output); } // namespace math } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/math/pool_3x3_int8.cpp b/src/operators/math/pool_3x3_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d344c489ae5da96ba5d82986911ca34865aadebc --- /dev/null +++ b/src/operators/math/pool_3x3_int8.cpp @@ -0,0 +1,564 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#ifdef POOL_OP +#ifdef _OPENMP +#include +#endif +#include "framework/tensor.h" +#include "operators/math/pool_3x3.h" +#if __ARM_NEON +#include +#endif // __ARM_NEON +#include +#include +namespace paddle_mobile { +namespace operators { +namespace math { +using framework::Tensor; +using std::max; +using std::min; +using std::vector; +template +static void make_paddings(const Tensor *input, Tensor *padded_input, + int32_t top, int32_t bottom, int32_t left, + int32_t right, T value) { + const int32_t batch_size = input->dims()[0]; + const int32_t c_in = input->dims()[1]; + const int32_t h_in = input->dims()[2]; + const int32_t w_in = input->dims()[3]; + const int32_t h_padded = h_in + top + bottom; + const int32_t w_padded = w_in + left + right; + padded_input->Resize({batch_size, c_in, h_padded, w_padded}); + T *padded_input_data = padded_input->mutable_data(); + const T *input_data = input->data(); + const int32_t input_channel_stride = h_in * w_in; + const int32_t input_batch_stride = c_in * input_channel_stride; + const int32_t padded_channel_stride = h_padded * w_padded; + const int32_t padded_batch_stride = c_in * padded_channel_stride; + for (int i = 0; i < batch_size; ++i) { +#pragma omp parallel for + for (int j = 0; j < c_in; ++j) { + const T *img_in = input_data + j * input_channel_stride; + T *img_padded = padded_input_data + j * padded_channel_stride; + int k = 0; + for (; k < top; ++k) { + for (int l = 0; l < w_padded; ++l) { + img_padded[l] = value; + } + img_padded += w_padded; + } + for (; k < top + h_in; ++k) { + int l = 0; + for (; l < left; ++l) { + img_padded[l] = value; + } + memcpy(img_padded + left, img_in, w_in * sizeof(T)); + l += w_in; + img_in += w_in; + for (; l < w_padded; ++l) { + img_padded[l] = value; + } + img_padded += w_padded; + } + for (; k < h_padded; ++k) { + for (int l = 0; l < w_padded; ++l) { + img_padded[l] = value; + } + img_padded += w_padded; + } + } + input_data += input_batch_stride; + padded_input_data += padded_batch_stride; + } + // input_data = input->data(); + // std::cout << "+++++++++++++++++++Origin begin++++++++++++++++++++" + // << std::endl; + // for (int i = 0; i < 1; ++i) { + // for (int j = 0; j < 1; ++j) { + // const T *img_in = input_data + j * input_channel_stride; + // for (int k = 0; k < h_in; ++k) { + // for (int l = 0; l < w_in; ++l) { + // std::cout << (int32_t)*img_in << "\t"; + // img_in++; + // } + // std::cout << std::endl; + // } + // } + // input_data += input_batch_stride; + // } + // std::cout << "+++++++++++++++++++Origin end++++++++++++++++++++" << + // std::endl; + // + // padded_input_data = padded_input->data(); + // std::cout << "******************Padding begin**********************" + // << std::endl; + // for (int i = 0; i < 1; ++i) { + // for (int j = 0; j < 1; ++j) { + // T *img_padded = padded_input_data + j * padded_channel_stride; + // for (int k = 0; k < h_padded; ++k) { + // for (int l = 0; l < w_padded; ++l) { + // std::cout << (int32_t)*img_padded << "\t"; + // img_padded++; + // } + // std::cout << std::endl; + // } + // } + // padded_input_data += padded_batch_stride; + // } + // std::cout << 
"******************Padding end**********************" + // << std::endl; +} +void Pool3x3Maxs1_int8(const Tensor *input, Tensor *output, int32_t pad_h, + int32_t pad_w) { + Tensor padded_input; + if (pad_h != 0 && pad_w != 0) { + int8_t value = -SCHAR_MAX; + make_paddings(input, &padded_input, pad_h, pad_h, pad_w, pad_w, value); + input = &padded_input; + } + const int32_t batch_size = input->dims()[0]; + const int32_t h_in = input->dims()[2]; + const int32_t w_in = input->dims()[3]; + const int8_t *input_data = input->data(); + const int32_t output_channels = output->dims()[1]; + const int32_t h_out = output->dims()[2]; + const int32_t w_out = output->dims()[3]; + int8_t *output_data = output->mutable_data(); + const int32_t outputdata_channel_stride = h_out * w_out; + const int32_t inputdata_channel_stride = h_in * w_in; + const int32_t input_batch_stride = output_channels * inputdata_channel_stride; + const int32_t output_batch_stride = + output_channels * outputdata_channel_stride; + // std::cout << "h_out = " << h_out << ", w_out=" << w_out << std::endl; + for (int i = 0; i < batch_size; ++i) { +#pragma omp parallel for + for (int j = 0; j < output_channels; ++j) { + const int8_t *img_in = input_data + j * inputdata_channel_stride; + int8_t *img_out = output_data + j * outputdata_channel_stride; + for (int k = 0; k < h_out; ++k) { + const int8_t *row0 = img_in + k * w_in; + const int8_t *row1 = img_in + (k + 1) * w_in; + const int8_t *row2 = img_in + (k + 2) * w_in; +#if __ARM_NEON + int32_t nw = w_out >> 4; + int32_t left_w = w_out & 0xf; + int32_t nw1 = left_w >> 3; + int32_t left_w1 = left_w & 0x7; +#if __aarch64__ + // TODO +#else + if (nw > 0) { +#define LOOP_LABEL "1" + // result: q15 + asm volatile( + "vld1.8 {q0}, [%[row0]]! \n\t" // q0=0-15 + "vld1.8 {q2}, [%[row1]]! \n\t" + "vld1.8 {q4}, [%[row2]]! \n\t" + + LOOP_LABEL + ": \n\t" + "vld1.8 {q1}, [%[row0]]! \n\t" // q1=16-31 + "vext.8 q6, q0, q1, #1 \n\t" + "vext.8 q7, q0, q1, #2 \n\t" + "vld1.8 {q3}, [%[row1]]! \n\t" + "vmax.s8 q15, q0, q6 \n\t" + "vmax.s8 q15, q15, q7 \n\t" + "vext.8 q6, q2, q3, #1 \n\t" + "vext.8 q7, q2, q3, #2 \n\t" + "vld1.8 {q5}, [%[row2]]! \n\t" + "vmax.s8 q14, q2, q6 \n\t" + "vmax.s8 q14, q14, q7 \n\t" + "vext.8 q6, q4, q5, #1 \n\t" + "vext.8 q7, q4, q5, #2 \n\t" + "vmax.s8 q13, q4, q6 \n\t" + "vmax.s8 q13, q13, q7 \n\t" + "vmax.s8 q15, q15, q14 \n\t" + "vmax.s8 q15, q15, q13 \n\t" + "vmov.s8 q0, q1 \n\t" + "vmov.s8 q2, q3 \n\t" + "vmov.s8 q4, q5 \n\t" + "vst1.8 {q15}, [%[img_out]]! \n\t" + "subs %[nw], #1 \n\t" + "bne " LOOP_LABEL + "b \n\t" + "sub %[row0], #16 \n\t" + "sub %[row1], #16 \n\t" + "sub %[row2], #16 \n\t" + : [nw] "+r"(nw), [row0] "+r"(row0), [row1] "+r"(row1), + [row2] "+r"(row2), [img_out] "+r"(img_out) + : + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q13", "q14", "q15"); +#undef LOOP_LABEL + } + if (nw1 > 0 || left_w1 > 0) { +#define PADDLE_LABEL_LESS8 "1" +#define PADDLE_LABEL_LESS8_SAVE "2" +#define PADDLE_LABEL_OVER "3" + // result: d15 + asm volatile( + "vld1.8 {d0}, [%[row0]]! \n\t" // d0=0-8 + "vld1.8 {d2}, [%[row1]]! \n\t" + "vld1.8 {d4}, [%[row2]]! \n\t" + "mov r0, #1 \n\t" + "cmp %[nw1], #0 \n\t" + "beq " PADDLE_LABEL_LESS8 + "f\n\t" + "vld1.8 {d1}, [%[row0]]! \n\t" // d1=9-15 + "vext.8 d6, d0, d1, #1 \n\t" + "vext.8 d7, d0, d1, #2 \n\t" + "vld1.8 {d3}, [%[row1]]! \n\t" + "vmax.s8 d15, d0, d6 \n\t" + "vmax.s8 d15, d15, d7 \n\t" + "vext.8 d6, d2, d3, #1 \n\t" + "vext.8 d7, d2, d3, #2 \n\t" + "vld1.8 {d5}, [%[row2]]! 
\n\t" + "vmax.s8 d14, d2, d6 \n\t" + "vmax.s8 d14, d14, d7 \n\t" + "vext.8 d6, d4, d5, #1 \n\t" + "vext.8 d7, d4, d5, #2 \n\t" + "vmax.s8 d13, d4, d6 \n\t" + "vmax.s8 d13, d13, d7 \n\t" + "vmax.s8 d15, d15, d14 \n\t" + "vmax.s8 d15, d15, d13 \n\t" + "vmov.s8 d0, d1 \n\t" + "vmov.s8 d2, d3 \n\t" + "vmov.s8 d4, d5 \n\t" + "vst1.8 {d15}, [%[img_out]]! \n\t" + + PADDLE_LABEL_LESS8 + ": \n\t" + "cmp %[left_w1], #0 \n\t" + "beq " PADDLE_LABEL_OVER + "f\n\t" + "vld1.8 {d1}, [%[row0]] \n\t" // d1=9-15 + "vext.8 d6, d0, d1, #1 \n\t" + "vext.8 d7, d0, d1, #2 \n\t" + "vld1.8 {d3}, [%[row1]] \n\t" + "vmax.s8 d15, d0, d6 \n\t" + "vmax.s8 d15, d15, d7 \n\t" + "vext.8 d6, d2, d3, #1 \n\t" + "vext.8 d7, d2, d3, #2 \n\t" + "vld1.8 {d5}, [%[row2]] \n\t" + "vmax.s8 d14, d2, d6 \n\t" + "vmax.s8 d14, d14, d7 \n\t" + "vext.8 d6, d4, d5, #1 \n\t" + "vext.8 d7, d4, d5, #2 \n\t" + "vmax.s8 d13, d4, d6 \n\t" + "vmax.s8 d13, d13, d7 \n\t" + "vmax.s8 d15, d15, d14 \n\t" + "vmax.s8 d15, d15, d13 \n\t" + + PADDLE_LABEL_LESS8_SAVE + ": \n\t" + "vst1.8 {d15[0]}, [%[img_out]], r0\n\t" + "add %[row0], %[row0], #1 \n\t" + "add %[row1], %[row1], #1 \n\t" + "add %[row2], %[row2], #1 \n\t" + "vext.8 d15, d15, d15, #1 \n\t" + "subs %[left_w1], #1 \n\t" + "bgt " PADDLE_LABEL_LESS8_SAVE "b \n\t" + + PADDLE_LABEL_OVER ": \n\t" + : [nw1] "+r"(nw1), [left_w1] "+r"(left_w1), [row0] "+r"(row0), + [row1] "+r"(row1), [row2] "+r"(row2), [img_out] "+r"(img_out) + : + : "cc", "memory", "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", + "d7", "d13", "d14", "d15"); +#undef PADDLE_LABEL_OVER +#undef PADDLE_LABEL_LESS8_SAVE +#undef PADDLE_LABEL_LESS8 + } +#endif // __aarch64__ +#else + int32_t left = w_out; + while (left > 0) { + const int8_t max0 = std::max(std::max(row0[0], row0[1]), row0[2]); + const int8_t max1 = std::max(std::max(row1[0], row1[1]), row1[2]); + const int8_t max2 = std::max(std::max(row2[0], row2[1]), row2[2]); + *img_out = std::max(std::max(max0, max1), max2); + row0 += 1; + row1 += 1; + row2 += 1; + img_out++; + left--; + } +#endif // __ARM_NEON + } + } + input_data += input_batch_stride; + output_data += output_batch_stride; + } +} +void Pool3x3Maxs2_int8(const Tensor *input, Tensor *output, int32_t pad_h, + int32_t pad_w) { + Tensor padded_input; + if (pad_h != 0 && pad_w != 0) { + int8_t value = -SCHAR_MAX; + make_paddings(input, &padded_input, pad_h, pad_h, pad_w, pad_w, value); + input = &padded_input; + } + const int32_t batch_size = input->dims()[0]; + const int32_t h_in = input->dims()[2]; + const int32_t w_in = input->dims()[3]; + const int32_t output_channels = output->dims()[1]; + const int32_t h_out = output->dims()[2]; + const int32_t w_out = output->dims()[3]; + const int32_t outputdata_channel_stride = h_out * w_out; + const int32_t inputdata_channel_stride = h_in * w_in; + const int32_t output_batch_stride = + output_channels * outputdata_channel_stride; + const int32_t input_batch_stride = output_channels * inputdata_channel_stride; + const int8_t *input_data = input->data(); + int8_t *output_data = output->mutable_data(); + for (int i = 0; i < batch_size; ++i) { +#pragma omp parallel for + for (int j = 0; j < output_channels; ++j) { + const int8_t *img_in = input_data + j * inputdata_channel_stride; + int8_t *img_out = output_data + j * outputdata_channel_stride; + for (int k = 0; k < h_out; ++k) { + const int8_t *row0 = img_in + 2 * k * w_in; + const int8_t *row1 = img_in + (2 * k + 1) * w_in; + const int8_t *row2 = img_in + (2 * k + 2) * w_in; +#if __ARM_NEON + int32_t nw = w_out >> 4; + int32_t left_w 
= w_out & 0xf; + int32_t nw1 = left_w >> 3; + int32_t left_w1 = left_w & 0x7; +#if __aarch64__ + // TODO +#else + if (nw > 0) { +#define LOOP_LABEL "1" + // result: q15 + asm volatile( + "vld2.8 {q0, q1}, [%[row0]]! \n\t" // q0=0-30, q1=1-31 + "vld2.8 {q2, q3}, [%[row1]]! \n\t" + "vld2.8 {q4, q5}, [%[row2]]! \n\t" + + LOOP_LABEL + ": \n\t" + "vmax.s8 q15, q0, q1 \n\t" + "vld2.8 {q6, q7}, [%[row0]]! \n\t" // q0=32-62, q1=33-63 + "vmax.s8 q14, q2, q3 \n\t" + "vmax.s8 q13, q4, q5 \n\t" + "vld2.8 {q8, q9}, [%[row1]]! \n\t" + "vext.8 q0, q0, q6, #1 \n\t" + "vmax.s8 q15, q15, q0 \n\t" + "vld2.8 {q10, q11}, [%[row2]]! \n\t" + "vext.8 q2, q2, q8, #1 \n\t" + "vmax.s8 q14, q14, q2 \n\t" + "vext.8 q4, q4, q10, #1 \n\t" + "vmax.s8 q13, q13, q4 \n\t" + "vmax.s8 q15, q15, q14 \n\t" + "vmax.s8 q15, q15, q13 \n\t" + "vmov.s8 q0, q6 \n\t" + "vmov.s8 q1, q7 \n\t" + "vmov.s8 q2, q8 \n\t" + "vmov.s8 q3, q9 \n\t" + "vmov.s8 q4, q10 \n\t" + "vmov.s8 q5, q11 \n\t" + "vst1.8 {q15}, [%[img_out]]! \n\t" + "subs %[nw], #1 \n\t" + "bne " LOOP_LABEL + "b \n\t" + "sub %[row0], #32 \n\t" + "sub %[row1], #32 \n\t" + "sub %[row2], #32 \n\t" + : [nw] "+r"(nw), [row0] "+r"(row0), [row1] "+r"(row1), + [row2] "+r"(row2), [img_out] "+r"(img_out) + : + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q13", "q14", "q15"); +#undef LOOP_LABEL + } + if (nw1 > 0 || left_w1 > 0) { +#define PADDLE_LABEL_LESS8 "1" +#define PADDLE_LABEL_LESS8_SAVE "2" +#define PADDLE_LABEL_OVER "3" + // result: d15 + asm volatile( + "vld2.8 {d0, d1}, [%[row0]]! \n\t" // d0=0-14, d1=1-15 + "vld2.8 {d2, d3}, [%[row1]]! \n\t" + "vld2.8 {d4, d5}, [%[row2]]! \n\t" + "mov r0, #1 \n\t" + "cmp %[nw1], #0 \n\t" + "beq " PADDLE_LABEL_LESS8 + "f\n\t" + "vmax.s8 d15, d0, d1 \n\t" + "vld2.8 {d6, d7}, [%[row0]]! \n\t" // d0=32-62, d1=33-63 + "vmax.s8 d14, d2, d3 \n\t" + "vmax.s8 d13, d4, d5 \n\t" + "vld2.8 {d8, d9}, [%[row1]]! \n\t" + "vext.8 d0, d0, d6, #1 \n\t" + "vmax.s8 d15, d15, d0 \n\t" + "vld2.8 {d10, d11}, [%[row2]]! \n\t" + "vext.8 d2, d2, d8, #1 \n\t" + "vmax.s8 d14, d14, d2 \n\t" + "vext.8 d4, d4, d10, #1 \n\t" + "vmax.s8 d13, d13, d4 \n\t" + "vmax.s8 d15, d15, d14 \n\t" + "vmax.s8 d15, d15, d13 \n\t" + "vmov.s8 d0, d6 \n\t" + "vmov.s8 d1, d7 \n\t" + "vmov.s8 d2, d8 \n\t" + "vmov.s8 d3, d9 \n\t" + "vmov.s8 d4, d10 \n\t" + "vmov.s8 d5, d11 \n\t" + "vst1.8 {d15}, [%[img_out]]! 
\n\t" + + PADDLE_LABEL_LESS8 + ": \n\t" + "cmp %[left_w1], #0 \n\t" + "beq " PADDLE_LABEL_OVER + "f\n\t" + "vmax.s8 d15, d0, d1 \n\t" + "vld2.8 {d6, d7}, [%[row0]] \n\t" // d0=32-62, d1=33-63 + "vmax.s8 d14, d2, d3 \n\t" + "vmax.s8 d13, d4, d5 \n\t" + "vld2.8 {d8, d9}, [%[row1]] \n\t" + "vext.8 d0, d0, d6, #1 \n\t" + "vmax.s8 d15, d15, d0 \n\t" + "vld2.8 {d10, d11}, [%[row2]] \n\t" + "vext.8 d2, d2, d8, #1 \n\t" + "vmax.s8 d14, d14, d2 \n\t" + "vext.8 d4, d4, d10, #1 \n\t" + "vmax.s8 d13, d13, d4 \n\t" + "vmax.s8 d15, d15, d14 \n\t" + "vmax.s8 d15, d15, d13 \n\t" + + PADDLE_LABEL_LESS8_SAVE + ": \n\t" + "vst1.8 {d15[0]}, [%[img_out]], r0\n\t" + "add %[row0], %[row0], #2 \n\t" + "add %[row1], %[row1], #2 \n\t" + "add %[row2], %[row2], #2 \n\t" + "vext.8 d15, d15, d15, #1 \n\t" + "subs %[left_w1], #1 \n\t" + "bgt " PADDLE_LABEL_LESS8_SAVE "b \n\t" + + PADDLE_LABEL_OVER ": \n\t" + : [nw1] "+r"(nw1), [left_w1] "+r"(left_w1), [row0] "+r"(row0), + [row1] "+r"(row1), [row2] "+r"(row2), [img_out] "+r"(img_out) + : + : "cc", "memory", "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", + "d7", "d8", "d9", "d10", "d11", "d13", "d14", "d15"); +#undef PADDLE_LABEL_OVER +#undef PADDLE_LABEL_LESS8_SAVE +#undef PADDLE_LABEL_LESS8 + } +#endif // __aarch64__ +#else + int32_t left = w_out; + while (left > 0) { + const int8_t max0 = std::max(std::max(row0[0], row0[1]), row0[2]); + const int8_t max1 = std::max(std::max(row1[0], row1[1]), row1[2]); + const int8_t max2 = std::max(std::max(row2[0], row2[1]), row2[2]); + *img_out = std::max(std::max(max0, max1), max2); + row0 += 2; + row1 += 2; + row2 += 2; + img_out++; + left--; + } +#endif // __ARM_NEON + } + } + input_data += input_batch_stride; + output_data += output_batch_stride; + } +} +void Pool3x3Max_int8(const vector &strides, const vector &paddings, + const Tensor *input, Tensor *output) { + const int batch_size = input->dims()[0]; + const int input_height = input->dims()[2]; + const int input_width = input->dims()[3]; + const int output_channels = output->dims()[1]; + const int output_height = output->dims()[2]; + const int output_width = output->dims()[3]; + // const int _kernel_size = 3; + const int stride = strides[0]; + // const int stride_width = strides[1]; + const int padding = paddings[0]; + // const int padding_width = paddings[1]; + const int8_t negative_max = -SCHAR_MAX; + const int input_channel_stride = input_height * input_width; + const int output_channel_stride = output_height * output_width; + const int8_t *input_data = input->data(); + int8_t *output_data = output->mutable_data(); + const int input_batch_stride = output_channels * input_channel_stride; + const int output_batch_stride = output_channels * output_channel_stride; + for (int i = 0; i < batch_size; ++i) { +#pragma omp parallel for + for (int c = 0; c < output_channels; ++c) { + const int8_t *input_seg = input_data + c * input_channel_stride; + int8_t *output_seg = output_data + c * output_channel_stride; + for (int ph = 0; ph < output_height; ph++) { + int hstart = ph * stride - padding; + int hend = min(hstart + 3, input_height); + hstart = max(hstart, 0); + for (int pw = 0; pw < output_width; pw++) { + int wstart = pw * stride - padding; + int wend = min(wstart + 3, input_width); + wstart = max(wstart, 0); + const int8_t *pos1 = input_seg + hstart * input_width + wstart; + const int8_t *pos2 = input_seg + (hstart + 1) * input_width + wstart; + const int8_t *pos3 = input_seg + (hstart + 2) * input_width + wstart; + int8_t *output_ptr = output_seg + ph * output_width + pw; + 
if (hend - hstart != 3 || wend - wstart != 3) { + int8_t max_value = -SCHAR_MAX; + for (int h = hstart; h < hend; h++) { + for (int w = wstart; w < wend; w++) { + int8_t value = input_seg[h * input_width + w]; + if (value > max_value) { + max_value = value; + } + } + } + output_seg[ph * output_width + pw] = max_value; + } else { +#if __ARM_NEON +#if __aarch64__ + // TODO +#else + asm volatile( + "vld1.8 {d0}, [%[pos1]] \n\t" + "vld1.8 {d1}, [%[pos2]] \n\t" + "vld1.8 {d2}, [%[pos3]] \n\t" + "vmax.s8 d3, d0, d1 \n\t" + "vmax.s8 d4, d2, d3 \n\t" + "vmov.s8 d4[3], %[negative_max] \n\t" + "vpmax.s8 d5, d4, d4 \n\t" + "vpmax.s8 d6, d5, d5 \n\t" + "vst1.8 {d6[0]},[%[output_ptr]] \n\t" + : + : [pos1] "r"(pos1), [pos2] "r"(pos2), [pos3] "r"(pos3), + [output_ptr] "r"(output_ptr), [negative_max] "r"(negative_max) + : "memory", "q0", "q1", "q2", "q3"); +#endif +#else + const int8_t max0 = std::max(std::max(pos1[0], pos1[1]), pos1[2]); + const int8_t max1 = std::max(std::max(pos2[0], pos2[1]), pos2[2]); + const int8_t max2 = std::max(std::max(pos3[0], pos3[1]), pos3[2]); + *output_ptr = std::max(std::max(max0, max1), max2); +#endif // __ARM_NEON + } + } + } + } + input_data += input_batch_stride; + output_data += output_batch_stride; + } +} +} // namespace math +} // namespace operators +} // namespace paddle_mobile +#endif diff --git a/src/operators/math/pooling.cpp b/src/operators/math/pooling.cpp index f5bcdf7fdb6b9245eda7d3557b293395bce23b24..17df4a26aa36509bf7d1253b8bc67f83d96b1aac 100644 --- a/src/operators/math/pooling.cpp +++ b/src/operators/math/pooling.cpp @@ -70,15 +70,15 @@ class PoolFunctor { int wend = std::min(wstart + ksize_width, input_width); wstart = std::max(wstart, 0); - T ele = pool_process.initial(); + auto ele = pool_process.initial(); for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { pool_process.compute(input_data[h * input_width + w], &ele); } } int pool_size = (hend - hstart) * (wend - wstart); - pool_process.finalize(static_cast(pool_size), &ele); - output_data[ph * output_width + pw] = ele; + pool_process.finalize(static_cast(pool_size), &ele); + output_data[ph * output_width + pw] = static_cast(ele); } } input_data += input_stride; @@ -88,8 +88,10 @@ class PoolFunctor { } }; -template class PoolFunctor, float>; +template class PoolFunctor, float>; template class PoolFunctor, float>; +template class PoolFunctor, int8_t>; +template class PoolFunctor, int8_t>; } // namespace math } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/math/pooling.h b/src/operators/math/pooling.h index 3ca868fa4de4b9fefdcd8c18c0d7107cc9f60b4f..4d94550cc373055b75167fbd473353d71a88d2ce 100644 --- a/src/operators/math/pooling.h +++ b/src/operators/math/pooling.h @@ -16,6 +16,8 @@ limitations under the License. */ #pragma once +#include +#include #include "common/log.h" #include "framework/tensor.h" #include "pool_2x2.h" @@ -37,24 +39,42 @@ namespace math { * in pool pooling, and finally takes the average. * MaxPoolGrad and AvgPoolGrad are gradient operations respectively. */ -template +template class MaxPool { public: - inline T initial() { return static_cast(-FLT_MAX); } + inline T initial() { + if (typeid(T) == typeid(int8_t)) { + return static_cast(-SCHAR_MAX); + } + return static_cast(-FLT_MAX); + } inline void compute(const T &x, T *y) { *y = *y > x ? 
*y : x; } inline void finalize(const T &pool_field, T *y) {} }; -template +template class AvgPool { public: - inline T initial() { return static_cast(0); } - - inline void compute(const T &x, T *y) { *y += x; } - - inline void finalize(const T &pool_field, T *y) { *y /= pool_field; } + inline Otype initial() { return static_cast(0); } + + inline void compute(const Itype &x, Otype *y) { *y += x; } + + inline void finalize(const float &pool_field, Otype *y) { + if (typeid(Itype) == typeid(int8_t)) { + float tmp = *y / pool_field; + if (tmp > SCHAR_MAX) { + *y = SCHAR_MAX; + } else if (tmp < -SCHAR_MAX) { + *y = -SCHAR_MAX; + } else { + *y = static_cast(std::round(tmp)); + } + } else { + *y /= pool_field; + } + } }; template diff --git a/src/operators/op_param.h b/src/operators/op_param.h index c4f5b180b832f320ac841f593ff76076b963f55d..381b66199892df9f24eca63470314e7652f5a72a 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -439,7 +439,7 @@ class ConvParam : public OpParam { #endif - private: + protected: RType *input_; RType *output_; RType *filter_; @@ -1707,7 +1707,19 @@ class FusionConvAddReluParam : public FusionConvAddParam { FusionConvAddReluParam(const VariableNameMap &inputs, const VariableNameMap &outputs, const AttributeMap &attrs, const Scope &scope) - : FusionConvAddParam(inputs, outputs, attrs, scope) {} + : FusionConvAddParam(inputs, outputs, attrs, scope) { +#ifdef FUSION_CONVADDRELU_INT8_OP + scale_ = OpParam::InputScaleFrom(inputs, scope); +#endif + } +#ifdef FUSION_CONVADDRELU_INT8_OP + typedef typename DtypeTensorTrait::gtype GType; + typedef typename DtypeTensorTrait::rtype RType; + const RType *InputScale() const { return scale_; } + + protected: + RType *scale_; +#endif }; #endif diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index bfd125ce5b75091cfac1a2a4e2f2f025da0178dc..3d202e2bd1f89894ed2c35abded57c42cc2ec9b9 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -269,8 +269,8 @@ if (NOT FOUND_MATCH) #gen test - ADD_EXECUTABLE(test-pool operators/test_pool_op.cpp test_helper.h test_include.h executor_for_test.h) - target_link_libraries(test-pool paddle-mobile) + ADD_EXECUTABLE(test-pool-op operators/test_pool_op.cpp test_helper.h test_include.h executor_for_test.h) + target_link_libraries(test-pool-op paddle-mobile) #gen test ADD_EXECUTABLE(test-softmax operators/test_softmax_op.cpp test_helper.h test_include.h executor_for_test.h) @@ -324,6 +324,10 @@ if (NOT FOUND_MATCH) ADD_EXECUTABLE(test-conv-add-relu-op operators/test_conv_add_relu_op.cpp test_helper.h test_include.h executor_for_test.h) target_link_libraries(test-conv-add-relu-op paddle-mobile) + # gen test + ADD_EXECUTABLE(test-conv-add-relu-int8-op operators/test_fusion_conv_add_relu_int8_op.cpp test_helper.h test_include.h) + target_link_libraries(test-conv-add-relu-int8-op paddle-mobile) + # gen test ADD_EXECUTABLE(test-conv-add-bn-relu-op operators/test_fusion_conv_add_bn_relu_op.cpp test_helper.h test_include.h executor_for_test.h) target_link_libraries(test-conv-add-bn-relu-op paddle-mobile) diff --git a/test/common/test_gemm_int8_accuracy.cpp b/test/common/test_gemm_int8_accuracy.cpp index 87f8d945648577ef1414417b57f4013d288dc043..a1920ba2bbfd6bf50357fcf05be0cf64dfc9d1fb 100644 --- a/test/common/test_gemm_int8_accuracy.cpp +++ b/test/common/test_gemm_int8_accuracy.cpp @@ -15,6 +15,7 @@ limitations under the License. 
*/ #include #include #include +#include #include #include "../test_helper.h" #include "common/log.h" @@ -54,6 +55,37 @@ void print_matirx(int m, int n, int ldc, int8_t *c) { std::cout << std::endl; } +int32_t qadd_int32(int32_t l, int32_t r) { + int64_t res = static_cast(l) + static_cast(r); + if (res > std::numeric_limits::max()) + return std::numeric_limits::max(); + else if (res < std::numeric_limits::min()) + return std::numeric_limits::min(); + else + return static_cast(res); +} + +// round to zero +float round2zero(float v) { + float res; + if (v > 0) + res = std::floor(v); + else if (v < 0) + res = std::ceil(v); + return res; +} + +int8_t qscale_int32(int32_t v, float scale) { + float res = static_cast(v) * scale; + res = round2zero(res); + if (res > 127) + return static_cast(127); + else if (res < -127) + return static_cast(-127); + else + return static_cast(res); +} + int do_sgemm(int m, int n, int k, bool relu, int pr) { int lda = k; int ldb = n; @@ -126,10 +158,97 @@ int do_sgemm(int m, int n, int k, bool relu, int pr) { return 0; } +int do_sgemm_with_bias(int m, int n, int k, bool relu, int pr) { + int lda = k; + int ldb = n; + int ldc = n; + float scale = 0.00628f; + default_random_engine e; + uniform_int_distribution pixel(-127, 127); + int8_t *a = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * m * k)); + int8_t *b = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * k * n)); + int8_t *c = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * m * n)); + int8_t *c1 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * m * n)); + + int32_t *bias = + static_cast(paddle_mobile::memory::Alloc(sizeof(int32_t) * m)); + + for (int i = 0; i < m * k; ++i) { + a[i] = pixel(e); + } + for (int i = 0; i < k * n; ++i) { + b[i] = pixel(e); + } + for (int i = 0; i < m; ++i) { + bias[i] = static_cast(pixel(e)); + } + for (int i = 0; i < m; ++i) { + int32_t bias_v = bias[i]; + for (int j = 0; j < n; ++j) { + int32_t r = 0; + for (int p = 0; p < k; p++) { + r += static_cast(a(i, p)) * static_cast(b(p, j)); + } + r = qadd_int32(r, bias_v); + if (relu) r = std::max(0, r); + c1(i, j) = qscale_int32(r, scale); + } + } + + paddle_mobile::operators::math::Gemm gemm; +#ifdef _OPENMP + gemm.Sgemm_omp(m, n, k, scale, a, lda, b, ldb, static_cast(0), c, ldc, + relu, bias); +#else + gemm.Sgemm(m, n, k, scale, a, lda, b, ldb, static_cast(0), c, ldc, + relu, bias); +#endif + int eq = 0; + int neq = 0; + for (int i = 0; i < m * n; ++i) { + if (c[i] == c1[i]) { + ++eq; + } else { + ++neq; + } + } + + if (pr > 0) { + std::cout << "A:" << std::endl; + print_matirx(m, k, lda, a); + std::cout << "B:" << std::endl; + print_matirx(k, n, ldb, b); + std::cout << "Bias:" << std::endl; + print_matirx(m, 1, 1, bias); + std::cout << "C:" << std::endl; + print_matirx(m, n, ldc, c); + std::cout << "C1:" << std::endl; + print_matirx(m, n, ldc, c1); + } + + std::cout << "mnk=" << m << " " << n << " " << k << " relu=" << relu + << " eq=" << eq << " neq=" << neq << std::endl; + + paddle_mobile::memory::Free(a); + paddle_mobile::memory::Free(b); + paddle_mobile::memory::Free(c); + paddle_mobile::memory::Free(c1); + paddle_mobile::memory::Free(bias); + + return 0; +} + int main() { #ifdef _OPENMP - omp_set_num_threads(8); + omp_set_num_threads(4); #endif + std::cout << "\n\n******************************************************\n\n" + << std::endl; + std::cout << "Test gemm without bias:" << std::endl; do_sgemm(9, 9, 9, false, 1); do_sgemm(10, 6, 12, false, 0); do_sgemm(512, 256, 384, false, 
0); @@ -140,5 +259,31 @@ int main() { do_sgemm(333, 797, 939, false, 0); do_sgemm(1024, 1024, 1024, false, 0); + std::cout << "\n\n******************************************************\n\n" + << std::endl; + std::cout << "Test gemm with bias:" << std::endl; + do_sgemm_with_bias(9, 9, 9, false, 1); + do_sgemm_with_bias(10, 6, 12, false, 0); + do_sgemm_with_bias(512, 256, 384, false, 0); + do_sgemm_with_bias(1366, 768, 256, false, 0); + do_sgemm_with_bias(1255, 755, 333, false, 0); + do_sgemm_with_bias(599, 1133, 393, false, 0); + do_sgemm_with_bias(777, 555, 999, false, 0); + do_sgemm_with_bias(333, 797, 939, false, 0); + do_sgemm_with_bias(1024, 1024, 1024, false, 0); + + std::cout << "\n\n******************************************************\n\n" + << std::endl; + std::cout << "Test gemm with relu and bias:" << std::endl; + do_sgemm_with_bias(9, 9, 9, true, 1); + do_sgemm_with_bias(10, 6, 12, true, 0); + do_sgemm_with_bias(512, 256, 384, true, 0); + do_sgemm_with_bias(1366, 768, 256, true, 0); + do_sgemm_with_bias(1255, 755, 333, true, 0); + do_sgemm_with_bias(599, 1133, 393, true, 0); + do_sgemm_with_bias(777, 555, 999, true, 0); + do_sgemm_with_bias(333, 797, 939, true, 0); + do_sgemm_with_bias(1024, 1024, 1024, true, 0); + return 0; } diff --git a/test/common/test_gemm_perf.cpp b/test/common/test_gemm_perf.cpp index 14da4ba284b5ac7b0660bd15de871fdf5ed04cdd..f25a290aef6e228aff0a84d2640486235e0116bf 100644 --- a/test/common/test_gemm_perf.cpp +++ b/test/common/test_gemm_perf.cpp @@ -28,7 +28,7 @@ limitations under the License. */ int main() { paddle_mobile::PaddleMobile paddle_mobile; - paddle_mobile.SetThreadNum(8); + paddle_mobile.SetThreadNum(4); Tensor aa, bb, cc; auto aaptr = aa.mutable_data({m, k}); auto bbptr = bb.mutable_data({k, n}); @@ -44,10 +44,12 @@ int main() { ccptr[i] = 2; } - Tensor aa_int8, bb_int8, cc_int8; + Tensor aa_int8, bb_int8, cc_int32, cc_int8; auto aaptr_int8 = aa_int8.mutable_data({m, k}); auto bbptr_int8 = bb_int8.mutable_data({k, n}); - auto ccptr_int8 = cc_int8.mutable_data({m, n}); + auto ccptr_int32 = cc_int32.mutable_data({m, n}); + auto ccptr_int8 = cc_int8.mutable_data({m, n}); + int32_t* bias_data = new int32_t[m]; for (int i = 0; i < m * k; ++i) { aaptr_int8[i] = static_cast(2); @@ -56,7 +58,11 @@ int main() { bbptr_int8[i] = static_cast(2); } for (int i = 0; i < m * n; ++i) { - ccptr_int8[i] = static_cast(2); + ccptr_int32[i] = static_cast(2); + } + + for (int i = 0; i < m; ++i) { + bias_data[i] = 2; } // float @@ -76,22 +82,41 @@ int main() { auto time2 = time(); std::cout << "float gemm cost :" << time_diff(time1, time2) / 10 << "ms\n"; - // int8_t + // int8_t without bias // warm-up 10 times for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul( - aa_int8, false, bb_int8, false, static_cast(1), &cc_int8, - static_cast(0), false, nullptr); + paddle_mobile::operators::math::matmul( + aa_int8, false, bb_int8, false, static_cast(1), &cc_int32, + static_cast(0)); } auto time3 = time(); for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul( - aa_int8, false, bb_int8, false, static_cast(1), &cc_int8, - static_cast(0), false, nullptr); + paddle_mobile::operators::math::matmul( + aa_int8, false, bb_int8, false, static_cast(1), &cc_int32, + static_cast(0)); } auto time4 = time(); std::cout << "int8_t gemm cost :" << time_diff(time3, time4) / 10 << "ms\n"; + // int8_t with bias&relu + // warm-up 10 times + for (int j = 0; j < 10; ++j) { + paddle_mobile::operators::math::matmul( + aa_int8, false, bb_int8, false, 
static_cast(0.618), &cc_int8, + static_cast(0), true, bias_data); + } + auto time5 = time(); + for (int j = 0; j < 10; ++j) { + paddle_mobile::operators::math::matmul( + aa_int8, false, bb_int8, false, static_cast(0.618), &cc_int8, + static_cast(0), true, bias_data); + } + auto time6 = time(); + std::cout << "int8_t gemm_with_bias_relu cost :" + << time_diff(time5, time6) / 10 << "ms\n"; + + delete[] bias_data; + return 0; } diff --git a/test/operators/test_fusion_conv_add_relu_int8_op.cpp b/test/operators/test_fusion_conv_add_relu_int8_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..42c68e5d04c03c143517a917e620d40636c382ec --- /dev/null +++ b/test/operators/test_fusion_conv_add_relu_int8_op.cpp @@ -0,0 +1,360 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef FUSION_CONVADDRELU_INT8_OP + +#include +#include +#include "../test_helper.h" +#include "../test_include.h" +#include "operators/fusion_conv_add_relu_int8_op.h" + +namespace paddle_mobile { +int32_t qadd_int32(int32_t l, int32_t r) { + int64_t res = static_cast(l) + static_cast(r); + if (res > std::numeric_limits::max()) + return std::numeric_limits::max(); + else if (res < std::numeric_limits::min()) + return std::numeric_limits::min(); + else + return static_cast(res); +} + +// round to zero +float round2zero(float v) { + float res; + if (v > 0) + res = std::floor(v); + else if (v < 0) + res = std::ceil(v); + return res; +} + +int8_t qscale_int32(int32_t v, float scale) { + float res = static_cast(v) * scale; + res = round2zero(res); + if (res > 127) + return static_cast(127); + else if (res < -127) + return static_cast(-127); + else + return static_cast(res); +} + +// Reference convolution from Caffe for checking results. +// accumulate through explicit loops over input, output, and filters. 
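// The check below first accumulates the raw convolution into an int32 buffer,
// then requantizes each element as
//   out = clamp(round_to_zero((acc + bias) * scale), -127, 127)
// using the qadd_int32 / round2zero / qscale_int32 helpers above, with the
// relu (max with 0) applied after the saturating bias add and before scaling,
// matching the output the fused int8 conv-add-relu kernel is expected to give.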
+template +void conv2d(const framework::Tensor *input, const framework::Tensor *filter, + const framework::Tensor *bias, const framework::AttributeMap &attrs, + framework::Tensor *output, float scale) { + framework::AttrReader attr_reader(attrs); + std::vector paddings = attr_reader.Get>("paddings"); + std::vector strides = attr_reader.Get>("strides"); + std::vector dilations = attr_reader.Get>("dilations"); + int groups = attr_reader.Get("groups"); + int kernel_h = filter->dims()[2]; + int kernel_w = filter->dims()[3]; + int pad_h = paddings[0]; + int pad_w = paddings[1]; + int stride_h = strides[0]; + int stride_w = strides[1]; + int dilation_h = dilations[0]; + int dilation_w = dilations[1]; + auto in_shape = input->dims(); + auto out_shape = output->dims(); + + const bool has_depth = 0; + int kernel_d, pad_d, stride_d, dilation_d; + if (has_depth) { + kernel_d = kernel_h; + stride_d = stride_h; + pad_d = pad_h; + dilation_d = dilation_h; + } else { + kernel_d = stride_d = dilation_d = 1; + pad_d = 0; + } + // Groups + int o_g = out_shape[1] / groups; + int k_g = in_shape[1] / groups; + int o_head, k_head; + // Convolution + vector weight_offset(4 + has_depth); + vector in_offset(4 + has_depth); + vector out_offset(4 + has_depth); + auto offset = [](const framework::Tensor *input, const vector &indics) { + framework::DDim shape = input->dims(); + size_t count = 0; + for (int i = 0; i < indics.size(); ++i) { + count *= shape[i]; + count += indics[i]; + } + return count; + }; + + const T *in_data = input->data(); + const T *w_data = filter->data(); + framework::Tensor output_32; + int32_t *out_data_32 = output_32.mutable_data(out_shape); + memset(out_data_32, 0, output_32.numel() * sizeof(int32_t)); + for (int n = 0; n < out_shape[0]; n++) { + for (int g = 0; g < groups; g++) { + o_head = o_g * g; + k_head = k_g * g; + for (int o = 0; o < o_g; o++) { + for (int k = 0; k < k_g; k++) { + for (int z = 0; z < (has_depth ? out_shape[2] : 1); z++) { + for (int y = 0; y < out_shape[2 + has_depth]; y++) { + for (int x = 0; x < out_shape[3 + has_depth]; x++) { + for (int r = 0; r < kernel_d; r++) { + for (int p = 0; p < kernel_h; p++) { + for (int q = 0; q < kernel_w; q++) { + int in_z = z * stride_d - pad_d + r * dilation_d; + int in_y = y * stride_h - pad_h + p * dilation_h; + int in_x = x * stride_w - pad_w + q * dilation_w; + if (in_z >= 0 && in_z < (has_depth ? 
in_shape[2] : 1) && + in_y >= 0 && in_y < in_shape[2 + has_depth] && + in_x >= 0 && in_x < in_shape[3 + has_depth]) { + weight_offset[0] = o + o_head; + weight_offset[1] = k; + if (has_depth) { + weight_offset[2] = r; + } + weight_offset[2 + has_depth] = p; + weight_offset[3 + has_depth] = q; + in_offset[0] = n; + in_offset[1] = k + k_head; + if (has_depth) { + in_offset[2] = in_z; + } + in_offset[2 + has_depth] = in_y; + in_offset[3 + has_depth] = in_x; + out_offset[0] = n; + out_offset[1] = o + o_head; + if (has_depth) { + out_offset[2] = z; + } + out_offset[2 + has_depth] = y; + out_offset[3 + has_depth] = x; + + out_data_32[offset(output, out_offset)] += + in_data[offset(input, in_offset)] * + w_data[offset(filter, weight_offset)]; + } + } + } + } + } + } + } + } + } + } + } + + T *out_data = output->mutable_data(); + int32_t n = out_shape[0]; + int32_t c = out_shape[1]; + int32_t h = out_shape[2]; + int32_t w = out_shape[3]; + const int32_t *bias_data = bias->data(); + for (int i = 0; i < n; ++i) { + for (int j = 0; j < c; ++j) { + int32_t bias_v = bias_data[j]; + for (int k = 0; k < h; ++k) { + for (int l = 0; l < w; ++l) { + int32_t tmp = out_data_32[i * c * h * w + j * h * w + k * w + l]; + tmp = qadd_int32(tmp, bias_v); + tmp = std::max(0, tmp); + out_data[i * c * h * w + j * h * w + k * w + l] = + qscale_int32(tmp, scale); + } + } + } + } +} + +template +int TestConvOp(int in_channels, int in_height, int in_width, int out_channels) { + int kernel_h = Kernel; + int kernel_w = Kernel; + int pad_h = Pad; + int pad_w = Pad; + int stride_h = Stride; + int stride_w = Stride; + int dilation_h = 1; + int dilation_w = 1; + + int batch_size = 1; + int input_c = in_channels; + int input_h = in_height; + int input_w = in_width; + int output_c = out_channels; + framework::DDim input_shape = + framework::make_ddim({batch_size, input_c, input_h, input_w}); + framework::DDim filter_shape = + framework::make_ddim({output_c, input_c, kernel_h, kernel_w}); + + int kernel_extent_h = dilation_h * (kernel_h - 1) + 1; + int kernel_extent_w = dilation_w * (kernel_w - 1) + 1; + int output_h = (input_h + 2 * pad_h - kernel_extent_h) / stride_h + 1; + int output_w = (input_w + 2 * pad_w - kernel_extent_w) / stride_w + 1; + framework::DDim output_shape = framework::make_ddim( + std::vector({batch_size, output_c, output_h, output_w})); + + framework::DDim bias_shape = framework::make_ddim({output_c}); + + VariableNameMap inputs; + VariableNameMap outputs; + auto scope = std::make_shared(); + inputs["Input"] = std::vector({"input"}); + inputs["Filter"] = std::vector({"filter"}); + inputs["Scale"] = std::vector({"scale"}); + inputs["Y"] = std::vector({"bias"}); + outputs["Out"] = std::vector({"output"}); + + auto input_var = scope.get()->Var("input"); + auto input = input_var->template GetMutable(); + SetupTensor(input, input_shape, -127, 127); + + auto filter_var = scope.get()->Var("filter"); + auto filter = filter_var->template GetMutable(); + SetupTensor(filter, filter_shape, -127, 127); + + auto scale_var = scope.get()->Var("scale"); + auto scale = scale_var->template GetMutable(); + scale->Resize(framework::make_ddim({1})); + float scale_v = 0.000828f; + scale->mutable_data()[0] = scale_v; + + auto bias_var = scope.get()->Var("bias"); + auto bias = bias_var->template GetMutable(); + SetupTensor(bias, bias_shape, -127, 127); + + auto output_var = scope.get()->Var("output"); + framework::AttributeMap attrs; + attrs["strides"].Set>(std::vector({stride_h, stride_w})); + 
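// The remaining attributes mirror the inputs FusionConvAddReluInt8Op reads:
// paddings/dilations are (h, w) pairs like strides above, groups is fixed to
// 1, and axis (presumably consumed by the fused elementwise add of the bias)
// is 0. The same AttributeMap is handed to the conv2d reference so both paths
// see identical geometry.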
attrs["paddings"].Set>(std::vector({pad_h, pad_w})); + attrs["dilations"].Set>( + std::vector({dilation_h, dilation_w})); + attrs["groups"].Set(1); + attrs["axis"].Set(0); + + auto *op = new operators::FusionConvAddReluInt8Op( + "fusion_conv_add_relu_int8", inputs, outputs, attrs, scope); + op->InferShape(); + op->Init(); + op->Run(); + + framework::Tensor output_cmp; + output_cmp.mutable_data(output_shape); + conv2d(input, filter, bias, attrs, &output_cmp, scale_v); + + // compare results + int eq = 0; + int neq = 0; + auto output = output_var->template Get(); + const T *output_data = output->data(); + T *output_cmp_data = output_cmp.data(); + for (int i = 0; i < output->numel(); ++i) { + PADDLE_MOBILE_ENFORCE( + output_data[i] == output_cmp_data[i], + "The execution of test_fusion_conv_add_relu_int8_op is failed!"); + if (output_data[i] == output_cmp_data[i]) { + ++eq; + } else { + ++neq; + } + } + std::cout << "eq = " << eq << ", neq = " << neq << std::endl; + delete op; + return 0; +} + +} // namespace paddle_mobile + +int main(int argc, char *argv[]) { + if (argc < 5) { + LOG(paddle_mobile::kLOG_INFO) + << "Usage:\n" + << " ./test-conv-add-relu-int8-op in_channels in_height in_width " + "out_channels\n" + << " params:\n" + << " -in_channels: int, input image's channels\n" + << " -in_height: int, input image's height\n" + << " -in_width: int, input image's width\n" + << " -out_channels: int, conv output channels\n"; + return 1; + } + int in_channels = atoi(argv[1]); + int in_height = atoi(argv[2]); + int in_width = atoi(argv[3]); + int out_channels = atoi(argv[4]); + // kernel = 3, pad = 1, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "int8_t, kernel=3, pad=1, stride=1"; + paddle_mobile::TestConvOp(in_channels, in_height, in_width, + out_channels); + // kernel = 7, pad = 0, stride = 2 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=0, stride=2"; + paddle_mobile::TestConvOp(in_channels, in_height, in_width, + out_channels); + // kernel = 7, pad = 1, stride = 2 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=1, stride=2"; + paddle_mobile::TestConvOp(in_channels, in_height, in_width, + out_channels); + // kernel = 7, pad = 3, stride = 2 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=2"; + paddle_mobile::TestConvOp(in_channels, in_height, in_width, + out_channels); + // kernel = 7, pad = 0, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=0, stride=1"; + paddle_mobile::TestConvOp(in_channels, in_height, in_width, + out_channels); + // kernel = 7, pad = 1, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=1, stride=1"; + paddle_mobile::TestConvOp(in_channels, in_height, in_width, + out_channels); + // kernel = 7, pad = 3, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=1"; + paddle_mobile::TestConvOp(in_channels, in_height, in_width, + out_channels); + // kernel = 7, pad = 5, stride = 3 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=5, stride=3"; + paddle_mobile::TestConvOp(in_channels, in_height, in_width, + out_channels); + // kernel = 7, pad = 3, stride = 4 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=4"; + paddle_mobile::TestConvOp(in_channels, in_height, in_width, + out_channels); + // kernel = 3, pad = 0, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=3, pad=0, stride=1"; + paddle_mobile::TestConvOp(in_channels, in_height, in_width, + out_channels); + // kernel = 3, pad = 1, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=3, pad=1, 
stride=1"; + paddle_mobile::TestConvOp(in_channels, in_height, in_width, + out_channels); + + // kernel = 5, pad = 0, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=0, stride=1"; + paddle_mobile::TestConvOp(in_channels, in_height, in_width, + out_channels); + + // kernel = 5, pad = 2, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=2, stride=1"; + paddle_mobile::TestConvOp(in_channels, in_height, in_width, + out_channels); +} + +#endif diff --git a/test/operators/test_mul_op.cpp b/test/operators/test_mul_op.cpp index 262ee960e1c777d369d3b510eb31e5ed47b3493c..99a2219749c7b16a2dff6a8c78621306f0aad1e6 100644 --- a/test/operators/test_mul_op.cpp +++ b/test/operators/test_mul_op.cpp @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include "../test_helper.h" #include "../test_include.h" #include "operators/mul_op.h" @@ -79,14 +80,14 @@ int TestMulOP() { PADDLE_MOBILE_ENFORCE( output_data[i] == c[i], "output[%d] = %d, output_cmp[%d] = %d", i, static_cast(output_data[i]), i, static_cast(c[i])); - if (static_cast(output_data[i] == c[i])) { + if (output_data[i] == c[i]) { ++eq; } else { ++neq; } } - DLOG << "mnk=" << m << " " << n << " " << k << " eq=" << eq - << " neq=" << neq; + std::cout << "mnk=" << m << " " << n << " " << k << " eq=" << eq + << " neq=" << neq << std::endl; delete op; return 0; } @@ -94,7 +95,7 @@ int TestMulOP() { int main() { paddle_mobile::PaddleMobile paddle_mobile; - paddle_mobile.SetThreadNum(8); + paddle_mobile.SetThreadNum(4); paddle_mobile::TestMulOP(); paddle_mobile::TestMulOP(); return 0; diff --git a/test/operators/test_pool_op.cpp b/test/operators/test_pool_op.cpp index 09470caf82eb90df56f7aa79b6873c2a6b94fbef..5784ac065496ccfde73e516802fb2f79f622836f 100644 --- a/test/operators/test_pool_op.cpp +++ b/test/operators/test_pool_op.cpp @@ -12,30 +12,281 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
+#include <iostream>
 #include "../test_include.h"
+#include "operators/kernel/central-arm-func/pool_arm_func.h"
 #include "operators/pool_op.h"
-int main() {
-  paddle_mobile::framework::Loader<paddle_mobile::CPU> loader;
-  auto program = loader.Load(std::string(g_googlenet));
-  if (program.originProgram == nullptr) {
-    DLOG << "program read file";
+namespace paddle_mobile {
+static int PoolOutputSize(int input_size, int filter_size, int padding,
+                          int stride, bool ceil_mode) {
+  int output_size;
+  if (!ceil_mode) {
+    output_size = (input_size - filter_size + 2 * padding) / stride + 1;
+  } else {
+    output_size =
+        (input_size - filter_size + 2 * padding + stride - 1) / stride + 1;
   }
+  return output_size;
+}
+
+template <typename T>
+static void PoolAvgPad0(std::vector<int> ksize, std::vector<int> strides,
+                        const framework::Tensor *input,
+                        framework::Tensor *out) {
+  const int32_t batch_size = input->dims()[0];
+  const int32_t input_c = input->dims()[1];
+  const int32_t input_h = input->dims()[2];
+  const int32_t input_w = input->dims()[3];
+  const int32_t out_c = out->dims()[1];
+  const int32_t out_h = out->dims()[2];
+  const int32_t out_w = out->dims()[3];
+  const int32_t kernel_h = ksize[0];
+  const int32_t kernel_w = ksize[1];
+  const int32_t stride_h = strides[0];
+  const int32_t stride_w = strides[1];
+  const int32_t inputdata_channel_stride = input_h * input_w;
+  const int32_t input_batch_stride = input_c * inputdata_channel_stride;
+  const int32_t outputdata_channel_stride = out_h * out_w;
+  const int32_t output_batch_stride = out_c * outputdata_channel_stride;
+  T *out_data = out->mutable_data<T>();
+  const T *input_data = input->data<T>();
+  const T **rows = new const T *[kernel_h];
+  for (int i = 0; i < batch_size; ++i) {
+    for (int j = 0; j < out_c; ++j) {
+      const T *img_in = input_data + j * inputdata_channel_stride;
+      T *img_out = out_data + j * outputdata_channel_stride;
+      for (int k = 0; k < out_h; ++k) {
+        for (int m = 0; m < kernel_h; ++m) {
+          rows[m] = img_in + (stride_h * k + m) * input_w;
+        }
+        int32_t left = out_w;
+        while (left > 0) {
+          float tmp = 0;
+          for (int m = 0; m < kernel_h; ++m) {
+            for (int l = 0; l < kernel_w; ++l) {
+              tmp += rows[m][l];
+            }
+          }
+          if (typeid(T) == typeid(int8_t)) {
+            tmp = tmp / (kernel_h * kernel_w);
+            if (tmp < -127) {
+              *img_out = -127;
+            } else if (tmp > 127) {
+              *img_out = 127;
+            } else {
+              *img_out = static_cast<T>(std::round(tmp));
+            }
+          } else {
+            *img_out = static_cast<T>(tmp / (kernel_h * kernel_w));
+          }
+          for (int m = 0; m < kernel_h; ++m) {
+            rows[m] += stride_w;
+          }
+          img_out++;
+          left--;
+        }
+      }
+    }
+    input_data += input_batch_stride;
+    out_data += output_batch_stride;
+  }
+  delete[] rows;
+}
+
+template <typename T, int CeilMode, int PoolType, int Kernel, int Pad,
+          int Stride>
+int TestPoolOp(int in_channels, int in_height, int in_width) {
+  int kernel_h = Kernel;
+  int kernel_w = Kernel;
+  int pad_h = Pad;
+  int pad_w = Pad;
+  int stride_h = Stride;
+  int stride_w = Stride;
+  bool ceil_mode = CeilMode != 0;
+  std::string pooling_type = (PoolType == 0 ? "max" : "avg");
"max" : "avg"); + + int batch_size = 1; + int input_c = in_channels; + int input_h = in_height; + int input_w = in_width; + + framework::DDim input_shape = + framework::make_ddim({batch_size, input_c, input_h, input_w}); + + std::vector output_shape_v({batch_size, input_c}); + output_shape_v.push_back( + PoolOutputSize(input_h, kernel_h, pad_h, stride_h, ceil_mode)); + output_shape_v.push_back( + PoolOutputSize(input_w, kernel_w, pad_w, stride_w, ceil_mode)); + + framework::DDim output_shape = framework::make_ddim(output_shape_v); - Executor4Test> - executor(program, "pool2d"); + VariableNameMap inputs; + VariableNameMap outputs; + auto scope = std::make_shared(); + inputs["X"] = std::vector({"input"}); + outputs["Out"] = std::vector({"output"}); - paddle_mobile::framework::Tensor input; - SetupTensor(&input, {1, 64, 112, 112}, static_cast(0), - static_cast(1)); - auto out_ddim = paddle_mobile::framework::make_ddim({1, 64, 56, 56}); - auto output = - executor.Predict(input, "conv2d_0.tmp_1", "pool2d_0.tmp_0", out_ddim); + auto input_var = scope.get()->Var("input"); + auto input = input_var->template GetMutable(); + SetupTensor(input, input_shape, -127, 127); - float *output_ptr = output->data(); - for (int j = 0; j < output->numel(); ++j) { - DLOG << " value of output: " << output_ptr[j]; + auto output_var = scope.get()->Var("output"); + framework::AttributeMap attrs; + attrs["pooling_type"].SetString(pooling_type); + attrs["ksize"].Set>(std::vector({kernel_h, kernel_w})); + attrs["strides"].Set>(std::vector({stride_h, stride_w})); + attrs["paddings"].Set>(std::vector({pad_h, pad_w})); + attrs["ceil_mode"].Set(false); + attrs["global_pooling"].Set(false); + + auto *op = new operators::PoolOp("pool2d", inputs, outputs, attrs, + scope); + op->InferShape(); + op->Init(); + op->Run(); + + framework::Tensor output_cmp; + output_cmp.mutable_data(output_shape); + if (pooling_type == "avg" && pad_h == 0 && pad_h == pad_w) { + PoolAvgPad0(std::vector{kernel_h, kernel_w}, + std::vector{stride_h, stride_w}, input, &output_cmp); + } else { + if (typeid(T) == typeid(int8_t)) { + operators::PoolBasic( + pooling_type, std::vector{kernel_h, kernel_w}, + std::vector{stride_h, stride_w}, std::vector{pad_h, pad_w}, + input, &output_cmp); + } else { + operators::PoolBasic( + pooling_type, std::vector{kernel_h, kernel_w}, + std::vector{stride_h, stride_w}, std::vector{pad_h, pad_w}, + input, &output_cmp); + } + } + + // compare results + int eq = 0; + int neq = 0; + auto output = output_var->template Get(); + const T *output_data = output->data(); + T *output_cmp_data = output_cmp.data(); + for (int i = 0; i < output->numel(); ++i) { + PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i], + "The execution of test_pool_op is failed!"); + if (output_data[i] == output_cmp_data[i]) { + ++eq; + } else { + ++neq; + } } + std::cout << "eq = " << eq << ", neq = " << neq << std::endl; + delete op; + return 0; } +} // namespace paddle_mobile + +int main(int argc, char *argv[]) { + if (argc < 4) { + LOG(paddle_mobile::kLOG_INFO) + << "Usage:\n" + << " ./test-pool-op in_channels in_height in_width \n" + << " params:\n" + << " -in_channels: int, input image's channels\n" + << " -in_height: int, input image's height\n" + << " -in_width: int, input image's width\n"; + return 1; + } + int in_channels = atoi(argv[1]); + int in_height = atoi(argv[2]); + int in_width = atoi(argv[3]); +#if __ARM_NEON + // kernel = 3, pad = 1, stride = 1 + LOG(paddle_mobile::kLOG_INFO) + << "float, ceil_mode=false, pooling_type=max, 
kernel=3, pad=1, stride=1"; + paddle_mobile::TestPoolOp(in_channels, in_height, + in_width); + // kernel = 3, pad = 0, stride = 2 + LOG(paddle_mobile::kLOG_INFO) + << "float, ceil_mode=false, pooling_type=max, kernel=3, pad=0, stride=2"; + paddle_mobile::TestPoolOp(in_channels, in_height, + in_width); +#endif + // kernel = 3, pad = 0, stride = 1 + LOG(paddle_mobile::kLOG_INFO) + << "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=0, stride=1"; + paddle_mobile::TestPoolOp(in_channels, in_height, + in_width); + // kernel = 3, pad = 1, stride = 1 + LOG(paddle_mobile::kLOG_INFO) + << "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=1, stride=1"; + paddle_mobile::TestPoolOp(in_channels, in_height, + in_width); + // kernel = 3, pad = 2, stride = 1 + LOG(paddle_mobile::kLOG_INFO) + << "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=2, stride=1"; + paddle_mobile::TestPoolOp(in_channels, in_height, + in_width); + // kernel = 3, pad = 0, stride = 2 + LOG(paddle_mobile::kLOG_INFO) + << "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=0, stride=2"; + paddle_mobile::TestPoolOp(in_channels, in_height, + in_width); + // kernel = 3, pad = 1, stride = 2 + LOG(paddle_mobile::kLOG_INFO) + << "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=1, stride=2"; + paddle_mobile::TestPoolOp(in_channels, in_height, + in_width); + // kernel = 3, pad = 0, stride = 2 + LOG(paddle_mobile::kLOG_INFO) + << "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=2, stride=2"; + paddle_mobile::TestPoolOp(in_channels, in_height, + in_width); + // kernel = 3, pad = 3, stride = 3 + LOG(paddle_mobile::kLOG_INFO) + << "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=3, stride=3"; + paddle_mobile::TestPoolOp(in_channels, in_height, + in_width); + // kernel = 7, pad = 0, stride = 1 + LOG(paddle_mobile::kLOG_INFO) + << "int8_t, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=1"; + paddle_mobile::TestPoolOp(in_channels, in_height, + in_width); + // kernel = 7, pad = 0, stride = 2 + LOG(paddle_mobile::kLOG_INFO) + << "int8_t, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=2"; + paddle_mobile::TestPoolOp(in_channels, in_height, + in_width); + // kernel = 7, pad = 0, stride = 3 + LOG(paddle_mobile::kLOG_INFO) + << "int8_t, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=3"; + paddle_mobile::TestPoolOp(in_channels, in_height, + in_width); + // kernel = 3, pad = 0, stride = 1 + LOG(paddle_mobile::kLOG_INFO) + << "int8_t, ceil_mode=false, pooling_type=avg, kernel=3, pad=0, stride=1"; + paddle_mobile::TestPoolOp(in_channels, in_height, + in_width); + // kernel = 3, pad = 0, stride = 3 + LOG(paddle_mobile::kLOG_INFO) + << "int8_t, ceil_mode=false, pooling_type=avg, kernel=3, pad=0, stride=3"; + paddle_mobile::TestPoolOp(in_channels, in_height, + in_width); + // kernel = 7, pad = 0, stride = 1 + LOG(paddle_mobile::kLOG_INFO) + << "float, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=1"; + paddle_mobile::TestPoolOp(in_channels, in_height, + in_width); + // kernel = 7, pad = 0, stride = 4 + LOG(paddle_mobile::kLOG_INFO) + << "float, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=4"; + paddle_mobile::TestPoolOp(in_channels, in_height, + in_width); + // kernel = 5, pad = 0, stride = 1 + LOG(paddle_mobile::kLOG_INFO) + << "float, ceil_mode=false, pooling_type=avg, kernel=5, pad=0, stride=1"; + paddle_mobile::TestPoolOp(in_channels, in_height, + in_width); +} diff --git a/tools/op.cmake b/tools/op.cmake index 
diff --git a/tools/op.cmake b/tools/op.cmake
index e2254c3261d53d142e77f09c001d9cbebb5f85ff..52d745565cedc81a0eeac49dda56dab08ffa1dc0 100644
--- a/tools/op.cmake
+++ b/tools/op.cmake
@@ -213,6 +213,7 @@ if(NOT FOUND_MATCH)
     set(FUSION_CONVADD_OP ON)
     set(FUSION_CONVADDPRELU_OP ON)
     set(FUSION_CONVADDRELU_OP ON)
+    set(FUSION_CONVADDRELU_INT8_OP ON)
     set(FUSION_FC_OP ON)
     set(LRN_OP ON)
     set(MUL_OP ON)
@@ -309,6 +310,9 @@ endif()
 if (FUSION_CONVADDRELU_OP)
     add_definitions(-DFUSION_CONVADDRELU_OP)
 endif()
+if (FUSION_CONVADDRELU_INT8_OP)
+    add_definitions(-DFUSION_CONVADDRELU_INT8_OP)
+endif()
 if (FUSION_CONVADDPRELU_OP)
     add_definitions(-DFUSION_CONVADDPRELU_OP)
 endif()
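Aside (not part of the patch): the new test-pool-op builds its expected output dimensions with the PoolOutputSize helper shown in test_pool_op.cpp. A small self-contained sketch of that arithmetic follows; the concrete input size 112 and the kernel/pad/stride combinations are illustrative assumptions chosen to match the configurations exercised by the test.

#include <iostream>

// Mirrors PoolOutputSize from test_pool_op.cpp: floor division by default,
// ceil mode adds (stride - 1) to the numerator before dividing.
static int PoolOutputSize(int input_size, int filter_size, int padding,
                          int stride, bool ceil_mode) {
  if (!ceil_mode) {
    return (input_size - filter_size + 2 * padding) / stride + 1;
  }
  return (input_size - filter_size + 2 * padding + stride - 1) / stride + 1;
}

int main() {
  // kernel=3, pad=1, stride=1 keeps a 112-wide feature map at 112.
  std::cout << PoolOutputSize(112, 3, 1, 1, false) << "\n";  // 112
  // kernel=7, pad=0, stride=2 on 112 gives 53 in floor mode, 54 in ceil mode.
  std::cout << PoolOutputSize(112, 7, 0, 2, false) << "\n";  // 53
  std::cout << PoolOutputSize(112, 7, 0, 2, true) << "\n";   // 54
  return 0;
}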