diff --git a/CMakeLists.txt b/CMakeLists.txt
index f5d68712a64b5a47657a7af9c0e6b47604893e23..5f5d094bed1daf7bf34f744dcc1eec8cf59d3af2 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -34,7 +34,7 @@ endif()
 
 if(DEBUGING)
   message(STATUS "debugging mode")
-  add_definitions(-DPADDLE_MOBILE_DEBUG)
+#  add_definitions(-DPADDLE_MOBILE_DEBUG)
 else()
 endif()
 
diff --git a/src/common/types.cpp b/src/common/types.cpp
index fcffcae1ab0322b839ace5447885e87fdb78fbf8..ef2d4ed1fc68bcb96fd1cdea10b654ba3bb05ffd 100644
--- a/src/common/types.cpp
+++ b/src/common/types.cpp
@@ -114,7 +114,7 @@ std::unordered_map<
     {G_OP_TYPE_DEPTHWISE_CONV, {{"Input"}, {"Output"}}},
     {G_OP_TYPE_FILL_CONSTANT, {{}, {"Out"}}},
     {G_OP_TYPE_FUSION_CONV_ADD_RELU, {{"Input"}, {"Out"}}},
-    {G_OP_TYPE_FUSION_CONV_ADD_RELU_INT8, {{"Input"}, {"Output"}}},
+    {G_OP_TYPE_FUSION_CONV_ADD_RELU_INT8, {{"Input"}, {"Out"}}},
     {G_OP_TYPE_FUSION_CONV_ADD_PRELU, {{"Input"}, {"Out"}}},
     {G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU, {{"Input"}, {"Out"}}},
     {G_OP_TYPE_IM2SEQUENCE, {{"X"}, {"Out"}}},
diff --git a/src/io/paddle_mobile.cpp b/src/io/paddle_mobile.cpp
index 4b50f15a868e3bdbb8434af0cc0d49a6cb54c6a5..cb7051468715179e1d9a5ead407941a20d9cb87a 100644
--- a/src/io/paddle_mobile.cpp
+++ b/src/io/paddle_mobile.cpp
@@ -153,7 +153,8 @@ double PaddleMobile<Dtype, P>::GetPredictTime() {
   paddle_mobile::operators::math::Gemm gemm;
   auto time1 = paddle_mobile::time();
   gemm.Sgemm(m, n, k, static_cast<float>(1), a, lda, b, ldb,
-             static_cast<float>(0), c, ldc, false, nullptr);
+             static_cast<float>(0), c, ldc, false,
+             static_cast<float *>(nullptr));
   auto time2 = paddle_mobile::time();
   double cost = paddle_mobile::time_diff(time1, time2);
   paddle_mobile::memory::Free(a);
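Why GetPredictTime() now needs the explicit cast: the float-only Sgemm declaration is deleted from gemm.h further down in this patch, so the call resolves against a member template whose bias type must be deduced from the last argument. A bare nullptr has type std::nullptr_t, which gives the compiler nothing to deduce the pointer type from. A minimal sketch of the rule, with illustrative names rather than the real Gemm class:

    #include <cstdint>

    template <typename Itype, typename Btype, typename Otype>
    void SgemmSketch(int m, const Itype *A, Otype *C, Btype *bias) {}

    void Caller(const float *a, float *c) {
      // SgemmSketch(1, a, c, nullptr);  // error: Btype cannot be deduced
      SgemmSketch(1, a, c, static_cast<float *>(nullptr));  // OK: Btype = float
    }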
diff --git a/src/operators/fusion_conv_add_relu_int8_op.h b/src/operators/fusion_conv_add_relu_int8_op.h
index c9ca511eaa615ec88e95f82c38be2d6456a6ab49..5e4b4c08065de8111ae5511b5e9448bacda74c8b 100644
--- a/src/operators/fusion_conv_add_relu_int8_op.h
+++ b/src/operators/fusion_conv_add_relu_int8_op.h
@@ -16,28 +16,26 @@ limitations under the License. */
 
 #pragma once
 
 #include <string>
 #include "framework/operator.h"
-#include "operators/kernel/conv_add_relu_int8_kernel.h"
+#include "operators/kernel/conv_add_relu_kernel.h"
 #include "operators/op_param.h"
 
 namespace paddle_mobile {
 namespace operators {
-using std::string;
 template <typename DeviceType, typename T>
 class FusionConvAddReluInt8Op
     : public framework::OperatorWithKernel<
-          DeviceType, FusionConvAddReluInt8Param<DeviceType>,
-          operators::ConvAddReluInt8Kernel<DeviceType, T>> {
+          DeviceType, FusionConvAddReluParam<DeviceType>,
+          operators::ConvAddReluKernel<DeviceType, T>> {
  public:
-  FusionConvAddReluInt8Op(const string &type, const VariableNameMap &inputs,
+  FusionConvAddReluInt8Op(const std::string &type,
+                          const VariableNameMap &inputs,
                           const VariableNameMap &outputs,
                           const framework::AttributeMap &attrs,
                           std::shared_ptr<framework::Scope> scope)
       : framework::OperatorWithKernel<
-            DeviceType, FusionConvAddReluInt8Param<DeviceType>,
-            operators::ConvAddReluInt8Kernel<DeviceType, T>>(
-            type, inputs, outputs, attrs, scope) {}
+            DeviceType, FusionConvAddReluParam<DeviceType>,
+            operators::ConvAddReluKernel<DeviceType, T>>(type, inputs, outputs,
+                                                         attrs, scope) {}
 
   void InferShape() const override;
-
- protected:
 };
 }  // namespace operators
 }  // namespace paddle_mobile
diff --git a/src/operators/kernel/arm/conv_add_relu_int8_kernel.cpp b/src/operators/kernel/arm/conv_add_relu_int8_kernel.cpp
deleted file mode 100644
index b73dcf0c02fa410d6de9f5deae33e54ff88fb72f..0000000000000000000000000000000000000000
--- a/src/operators/kernel/arm/conv_add_relu_int8_kernel.cpp
+++ /dev/null
@@ -1,39 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#ifdef FUSION_CONVADDRELU_INT8_OP
-
-#include "operators/kernel/conv_add_relu_int8_kernel.h"
-#include "operators/kernel/central-arm-func/conv_add_relu_int8_arm_func.h"
-
-namespace paddle_mobile {
-namespace operators {
-
-template <>
-bool ConvAddReluInt8Kernel<CPU, int8_t>::Init(
-    FusionConvAddReluInt8Param<CPU> *param) {
-  return true;
-}
-
-template <>
-void ConvAddReluInt8Kernel<CPU, int8_t>::Compute(
-    const FusionConvAddReluInt8Param<CPU> &param) {
-  ConvAddReluInt8Compute<int8_t>(param);
-}
-template class ConvAddReluInt8Kernel<CPU, int8_t>;
-
-}  // namespace operators
-}  // namespace paddle_mobile
-
-#endif  // FUSION_CONVADDRELU_INT8_OP
diff --git a/src/operators/kernel/arm/conv_add_relu_kernel.cpp b/src/operators/kernel/arm/conv_add_relu_kernel.cpp
index 211d6d8487bfd4afc71d74e5ecbff149ad34e466..150bf1d77e33b99cbd7786f3885f2012270c0c78 100644
--- a/src/operators/kernel/arm/conv_add_relu_kernel.cpp
+++ b/src/operators/kernel/arm/conv_add_relu_kernel.cpp
@@ -28,10 +28,24 @@ bool ConvAddReluKernel<CPU, float>::Init(FusionConvAddReluParam<CPU> *param) {
 
 template <>
 void ConvAddReluKernel<CPU, float>::Compute(
     const FusionConvAddReluParam<CPU> &param) {
-  ConvAddReluCompute<float>(param);
+  ConvAddReluCompute<float, float>(param);
 }
 template class ConvAddReluKernel<CPU, float>;
 
+#ifdef FUSION_CONVADDRELU_INT8_OP
+template <>
+bool ConvAddReluKernel<CPU, int8_t>::Init(FusionConvAddReluParam<CPU> *param) {
+  return true;
+}
+
+template <>
+void ConvAddReluKernel<CPU, int8_t>::Compute(
+    const FusionConvAddReluParam<CPU> &param) {
+  ConvAddReluCompute<int8_t, int32_t>(param);
+}
+template class ConvAddReluKernel<CPU, int8_t>;
+#endif
+
 }  // namespace operators
 }  // namespace paddle_mobile
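With the dedicated int8 kernel deleted, both data types now route through ConvAddReluKernel; only the template arguments of the shared compute function differ. A simplified sketch of the consolidation pattern this patch applies (framework tensor types replaced with raw pointers; names are illustrative):

    #include <cstdint>

    // Shared body: the im2col/vol2col + gemm + fused bias/relu pipeline,
    // written once, independent of the element types involved.
    template <typename P, typename S>
    void ConvAddReluComputeSketch(const P *input, const S *bias) {}

    // Float kernel: float inputs, float bias.
    void FloatCompute(const float *input, const float *bias) {
      ConvAddReluComputeSketch<float, float>(input, bias);
    }

    // Int8 kernel: int8 inputs, int32 bias matching the int32 accumulators.
    void Int8Compute(const int8_t *input, const int32_t *bias) {
      ConvAddReluComputeSketch<int8_t, int32_t>(input, bias);
    }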
diff --git a/src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h
index 9ea8dbf0c115e1870b19965efd2441aa940aaa9f..9e46790cfe6f8d21f6c466c64853b5efc7db927c 100644
--- a/src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h
+++ b/src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h
@@ -25,22 +25,31 @@ limitations under the License. */
 namespace paddle_mobile {
 namespace operators {
 
-template <typename P>
+template <typename P, typename S>
 void ConvAddReluCompute(const FusionConvAddReluParam<CPU> &param) {
   const Tensor *input = param.Input();
   Tensor filter = *param.Filter();
   Tensor bias = *param.Bias();
-  int axis = param.Axis();
+  int32_t axis = param.Axis();
+  S *bias_data = bias.data<S>();
   Tensor *output = param.Output();
-  float *biase_data = bias.data<float>();
   output->mutable_data<P>();
 
-  int groups = param.Groups();
-  std::vector<int> strides = param.Strides();
-  std::vector<int> paddings = param.Paddings();
-  std::vector<int> dilations = param.Dilations();
+  float alpha = 1.0f;
+  float beta = 1.0f;
 
-  const int batch_size = static_cast<int>(input->dims()[0]);
+#ifdef FUSION_CONVADDRELU_INT8_OP
+  Tensor scale = *param.InputScale();
+  alpha = scale.data<float>()[0];
+  beta = 0.0f;
+#endif
+
+  int32_t groups = param.Groups();
+  std::vector<int32_t> strides = param.Strides();
+  std::vector<int32_t> paddings = param.Paddings();
+  std::vector<int32_t> dilations = param.Dilations();
+
+  const int32_t batch_size = static_cast<int32_t>(input->dims()[0]);
 
   std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
@@ -62,13 +71,13 @@ void ConvAddReluCompute(const FusionConvAddReluParam<CPU> &param) {
   Tensor col;
   Tensor col_matrix;
   if (is_expand) {
-    col.mutable_data<float>(col_shape);
+    col.mutable_data<P>(col_shape);
     col_matrix.ShareDataWith(col);
     col_matrix.Resize(col_matrix_shape);
   }
 
   framework::DDim input_shape = framework::slice_ddim(
-      input->dims(), 1, static_cast<int>(input->dims().size()));
+      input->dims(), 1, static_cast<int32_t>(input->dims().size()));
 
   framework::DDim filter_matrix_shape = {filter.dims()[0],
                                          filter.numel() / filter.dims()[0]};
@@ -78,17 +87,17 @@ void ConvAddReluCompute(const FusionConvAddReluParam<CPU> &param) {
       output->numel() / (output->dims()[0] * output->dims()[1])};
 
   // convolution operator: im2col(or vol2col) + gemm
-  int in_step = static_cast<int>(input->dims()[1]) / groups;
-  int out_step = static_cast<int>(output->dims()[1]) / groups;
+  int32_t in_step = static_cast<int32_t>(input->dims()[1]) / groups;
+  int32_t out_step = static_cast<int32_t>(output->dims()[1]) / groups;
 
-  math::Vol2ColFunctor<CPU, float> vol2col;
-  math::Im2ColFunctor<math::ColFormat::kCFO, CPU, float> im2col;
+  math::Vol2ColFunctor<CPU, P> vol2col;
+  math::Im2ColFunctor<math::ColFormat::kCFO, CPU, P> im2col;
 
-  for (int i = 0; i < batch_size; i++) {
+  for (int32_t i = 0; i < batch_size; i++) {
     Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
     Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
 
-    for (int g = 0; g < groups; g++) {
+    for (int32_t g = 0; g < groups; g++) {
       Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
 
       if (!is_expand) {
@@ -98,8 +107,8 @@ void ConvAddReluCompute(const FusionConvAddReluParam<CPU> &param) {
       } else if (data_dim == 2U) {
         // im2col
         im2col(in_slice, dilations, strides,
-               std::vector<int>{paddings[0], paddings[1], paddings[0],
-                                paddings[1]},
+               std::vector<int32_t>{paddings[0], paddings[1], paddings[0],
+                                    paddings[1]},
                &col);
       } else if (data_dim == 3U) {
         // vol2col
@@ -109,9 +118,9 @@ void ConvAddReluCompute(const FusionConvAddReluParam<CPU> &param) {
       // gemm
       Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
       Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-      math::matmul<float>(filter_slice, false, col_matrix, false,
-                          static_cast<float>(1), &out_slice,
-                          static_cast<float>(1), true, biase_data);
+
+      math::matmul(filter_slice, false, col_matrix, false, alpha, &out_slice,
+                   beta, true, bias_data);
     }
   }
 }
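The merged compute function reduces the int8-specific behavior to two scalars: alpha carries the quantization scale (1.0f on the float path), and beta, which weights the prior contents of the output, drops from 1.0f to 0.0f on the int8 path so fresh values are written instead of accumulated. As a rough scalar model of the fused epilogue — the exact ordering of bias addition and relu lives inside math::matmul and its write-back routines, so treat this as a sketch, not the implementation:

    #include <algorithm>

    // One output element: `dot` is the filter/column dot product, `prev`
    // the prior contents of C. Schematic only; assumes bias is added after
    // the alpha scaling, consistent with the float path (alpha=1, beta=1).
    inline float FusedEpilogueSketch(float dot, float prev, float bias,
                                     float alpha, float beta, bool relu) {
      float v = alpha * dot + beta * prev + bias;
      return relu ? std::max(v, 0.0f) : v;
    }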
diff --git a/src/operators/kernel/central-arm-func/conv_add_relu_int8_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_relu_int8_arm_func.h
deleted file mode 100644
index b5d35f206f1ebf57bba81091c5d0e8c4ddb74457..0000000000000000000000000000000000000000
--- a/src/operators/kernel/central-arm-func/conv_add_relu_int8_arm_func.h
+++ /dev/null
@@ -1,125 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#ifdef FUSION_CONVADDRELU_INT8_OP
-
-#pragma once
-#include <vector>
-#include "operators/math/conv_func.h"
-#include "operators/math/im2col.h"
-#include "operators/math/math_function.h"
-#include "operators/math/vol2col.h"
-#include "operators/op_param.h"
-
-namespace paddle_mobile {
-namespace operators {
-
-template <typename P>
-void ConvAddReluInt8Compute(const FusionConvAddReluInt8Param<CPU> &param) {
-  const Tensor *input = param.Input();
-  Tensor filter = *param.Filter();
-  Tensor bias = *param.Bias();
-  Tensor scale = *param.InputScale();
-  int32_t axis = param.Axis();
-  Tensor *output = param.Output();
-  output->mutable_data<P>();
-
-  int32_t *bias_data = bias.data<int32_t>();
-  float scale_v = scale.data<float>()[0];
-
-  int32_t groups = param.Groups();
-  std::vector<int32_t> strides = param.Strides();
-  std::vector<int32_t> paddings = param.Paddings();
-  std::vector<int32_t> dilations = param.Dilations();
-
-  const int32_t batch_size = static_cast<int32_t>(input->dims()[0]);
-
-  std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
-
-  std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
-  size_t data_dim = filter_shape_vec.size() - 2;
-  std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
-  col_shape_vec[0] = input->dims()[1] / groups;
-  for (size_t j = 0; j < data_dim; ++j) {
-    col_shape_vec[j + 1] = filter_shape_vec[j + 2];
-    col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
-  }
-  framework::DDim col_shape(framework::make_ddim(col_shape_vec));
-
-  framework::DDim col_matrix_shape =
-      framework::flatten_to_2d(col_shape, data_dim + 1);
-
-  bool is_expand =
-      math::IsExpand(filter_shape_vec, strides, paddings, dilations);
-  Tensor col;
-  Tensor col_matrix;
-  if (is_expand) {
-    col.mutable_data<P>(col_shape);
-    col_matrix.ShareDataWith(col);
-    col_matrix.Resize(col_matrix_shape);
-  }
-
-  framework::DDim input_shape = framework::slice_ddim(
-      input->dims(), 1, static_cast<int32_t>(input->dims().size()));
-
-  framework::DDim filter_matrix_shape = {filter.dims()[0],
-                                         filter.numel() / filter.dims()[0]};
-  filter.Resize(filter_matrix_shape);
-  framework::DDim output_matrix_shape = {
-      output->dims()[1],
-      output->numel() / (output->dims()[0] * output->dims()[1])};
-
-  // convolution operator: im2col(or vol2col) + gemm
-  int32_t in_step = static_cast<int32_t>(input->dims()[1]) / groups;
-  int32_t out_step = static_cast<int32_t>(output->dims()[1]) / groups;
-
-  math::Vol2ColFunctor<CPU, int8_t> vol2col;
-  math::Im2ColFunctor<math::ColFormat::kCFO, CPU, int8_t> im2col;
-
-  for (int32_t i = 0; i < batch_size; i++) {
-    Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
-    Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
-
-    for (int32_t g = 0; g < groups; g++) {
-      Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
-
-      if (!is_expand) {
-        col.ShareDataWith(in_slice);
-        col_matrix.ShareDataWith(col);
-        col_matrix.Resize(col_matrix_shape);
-      } else if (data_dim == 2U) {
-        // im2col
-        im2col(in_slice, dilations, strides,
-               std::vector<int32_t>{paddings[0], paddings[1], paddings[0],
-                                    paddings[1]},
-               &col);
-      } else if (data_dim == 3U) {
-        // vol2col
-        vol2col(in_slice, dilations, strides, paddings, &col);
-      }
-
-      // gemm
-      Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
-      Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-
-      math::matmul(filter_slice, false, col_matrix, false, scale_v, &out_slice,
-                   static_cast<float>(0), true, bias_data);
-    }
-  }
-}
-
-}  // namespace operators
-}  // namespace paddle_mobile
-
-#endif  // FUSION_CONVADDRELU_INT8_OP
diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.h b/src/operators/kernel/central-arm-func/conv_arm_func.h
index f746eae470ede7f6cc21b8abde462eafd46ab89e..95299b0799764639bfb36721f4707b1382533bb6 100644
--- a/src/operators/kernel/central-arm-func/conv_arm_func.h
+++ b/src/operators/kernel/central-arm-func/conv_arm_func.h
@@ -106,16 +106,9 @@ inline void GemmConv(const ConvParam<CPU> &param) {
       // gemm
       Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
       Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-
-      if (param.Input()->type() == typeid(int8_t)) {
-        math::matmul(filter_slice, false, col_matrix, false,
-                     static_cast<float>(1), &out_slice, static_cast<float>(0),
-                     false, static_cast<int32_t *>(nullptr));
-      } else {
-        math::matmul(filter_slice, false, col_matrix, false,
-                     static_cast<float>(1), &out_slice, static_cast<float>(0),
-                     false, static_cast<float *>(nullptr));
-      }
+      math::matmul(filter_slice, false, col_matrix, false,
+                   static_cast<float>(1), &out_slice, static_cast<float>(0),
+                   false, static_cast<Otype *>(nullptr));
     }
   }
 }
diff --git a/src/operators/kernel/conv_add_relu_int8_kernel.h b/src/operators/kernel/conv_add_relu_int8_kernel.h
deleted file mode 100644
index ecd9f3d8630057705f445dc042f0cd1f58725fa8..0000000000000000000000000000000000000000
--- a/src/operators/kernel/conv_add_relu_int8_kernel.h
+++ /dev/null
@@ -1,45 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#ifdef FUSION_CONVADDRELU_INT8_OP
-
-#pragma once
-
-#include <vector>
-#include "framework/ddim.h"
-#include "framework/operator.h"
-#include "operators/math/conv_func.h"
-#include "operators/math/im2col.h"
-#include "operators/math/math_function.h"
-#include "operators/math/vol2col.h"
-#include "operators/op_param.h"
-
-namespace paddle_mobile {
-namespace operators {
-
-using framework::DDim;
-using framework::OpKernelBase;
-
-template <typename DeviceType, typename T>
-class ConvAddReluInt8Kernel
-    : public OpKernelBase<DeviceType, FusionConvAddReluInt8Param<DeviceType>> {
- public:
-  void Compute(const FusionConvAddReluInt8Param<DeviceType> &param);
-  bool Init(FusionConvAddReluInt8Param<DeviceType> *param);
-};
-
-}  // namespace operators
-}  // namespace paddle_mobile
-
-#endif  // FUSION_CONVADDRELU_INT8_OP
diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp
index c17b2a5e4df0f0ca88da79a9ce55c2ecae0316b5..297ca2538d5ee06cec3c8ed25fb4519aa5e1a827 100644
--- a/src/operators/math/gemm.cpp
+++ b/src/operators/math/gemm.cpp
@@ -2924,6 +2924,7 @@ void Gemm::WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc,
 #endif  // __ARM_NEON
 
 // 32-bit float matrix multiplication
+template <>
 void Gemm::Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
                  const float *B, int ldb, float beta, float *C, int ldc,
                  bool relu, float *bias) {
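The template <> added above turns the float Sgemm into a full explicit specialization of the member template declared in gemm.h (next file). That is what lets its definition stay in gemm.cpp: a full specialization is an ordinary function, defined once, whereas a primary template's body must generally be visible wherever it is instantiated — which is also why the int8 body moves into the header below. In sketch form, with illustrative names:

    #include <cstdint>

    struct GemmSketch {
      // Primary template: its body must live in the header.
      template <typename Itype, typename Btype, typename Otype>
      void Sgemm(const Itype *A, Otype *C, Btype *bias);
    };

    // Full specialization: an ordinary function, definable in one .cpp file
    // (inline here only so this sketch is self-contained in a header).
    template <>
    inline void GemmSketch::Sgemm(const float *A, float *C, float *bias) {}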
diff --git a/src/operators/math/gemm.h b/src/operators/math/gemm.h
index e409fe07dc55bcf68748f0f25b3b63480d25cd56..bccddffa5649a31759ad7a7fce0fb037c526f6df 100644
--- a/src/operators/math/gemm.h
+++ b/src/operators/math/gemm.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include <string>
 #include "common/log.h"
+#include "memory/t_malloc.h"
 
 // Matrix element access macros; assumes matrices are stored row-major
 #define A(i, j) A[(i)*lda + (j)]
@@ -163,11 +164,6 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
                      float *new_bias);
   */
 
-  // 32-bit float matrix multiplication
-  void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
-             const float *B, int ldb, float beta, float *C, int ldc, bool relu,
-             float *bias);
-
   // 32-bit float matrix multiplication, with batchnorm applied to the result
   void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
                    const float *B, int ldb, float beta, float *C, int ldc,
@@ -201,11 +197,13 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
                            int32_t ldc);
 
   // 8 bits int inner product
+  template <typename Otype>
   void InnerKernel(int32_t mc, int32_t nc, float alpha, const int8_t *a,
-                   const int8_t *b, float beta, int32_t *c, int32_t *C,
+                   const int8_t *b, float beta, int32_t *c, Otype *C,
                    int32_t ldc, bool relu);
 
+  template <typename Otype>
   void InnerKernelWithBias(int32_t mc, int32_t nc, float alpha, const int8_t *a,
-                           const int8_t *b, float beta, int32_t *c, int8_t *C,
+                           const int8_t *b, float beta, int32_t *c, Otype *C,
                            int32_t ldc, bool relu, int32_t *bias);
 
@@ -229,12 +227,15 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
                         const int8_t *B, int32_t ldb, int8_t *buffer);
 
   // 8 bits int matrix product
+  template <typename Itype, typename Btype, typename Otype>
+  void Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const Itype *A,
+             int32_t lda, const Itype *B, int32_t ldb, float beta, Otype *C,
+             int32_t ldc, bool relu, Btype *bias);
+  template <typename Otype>
   void Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
-             int32_t lda, const int8_t *B, int32_t ldb, float beta, int32_t *C,
-             int32_t ldc, bool relu, int32_t *bias);
-  void Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
-             int32_t lda, const int8_t *B, int32_t ldb, float beta, int8_t *C,
+             int32_t lda, const int8_t *B, int32_t ldb, float beta, Otype *C,
              int32_t ldc, bool relu, int32_t *bias);
+
   void Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
                  int32_t lda, const int8_t *B, int32_t ldb, float beta,
                  int32_t *C, int32_t ldc, bool relu, int32_t *bias);
@@ -266,6 +267,71 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
   int8_t *zero_int8;
 };
 
+// 8 bits int matrix product (m*k x k*n)
+template <typename Otype>
+void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
+                 int32_t lda, const int8_t *B, int32_t ldb, float beta,
+                 Otype *C, int32_t ldc, bool relu, int32_t *bias) {
+  // L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73)
+  // L2 cache is 0.5~4 Mib (Contex-A72 cluster)
+  int32_t L1 = 32 * 1024;
+  int32_t L2 = 512 * 1024;
+
+  const int32_t k_complete = (k + 15) - ((k + 15) & 15);
+  KC = k_complete;
+  MC = L1 / (KC * sizeof(int8_t));
+  NC = L2 / (KC * sizeof(int8_t));
+
+  // make sure MC is multiple of MR_INT8, and NC is multiple of NR_INT8
+  if (MC == 0) {
+    MC = MR_INT8;
+  } else {
+    int32_t mblock_num = (m + MC - 1) / MC;
+    MC = (m + mblock_num - 1) / mblock_num;
+    MC = (MC + MR_INT8 - 1) / MR_INT8 * MR_INT8;
+  }
+  // DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n";
+  if (NC == 0) {
+    NC = NR_INT8;
+  } else {
+    int32_t nblock_num = (n + NC - 1) / NC;
+    NC = (n + nblock_num - 1) / nblock_num;
+    NC = (NC + NR_INT8 - 1) / NR_INT8 * NR_INT8;
+  }
+  // DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n";
+  packedA_int8 = static_cast<int8_t *>(
+      paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC));
+  packedB_int8 = static_cast<int8_t *>(
+      paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC));
+  packedC_int32 = static_cast<int32_t *>(
+      paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC));
+  zero_int8 =
+      static_cast<int8_t *>(paddle_mobile::memory::Alloc(sizeof(int8_t) * k));
+
+  memset(static_cast<void *>(zero_int8), 0, sizeof(int8_t) * k);
+  int32_t mc, nc;
+  for (int32_t j = 0; j < n; j += NC) {
+    nc = s_min(n - j, NC);
+    PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, packedB_int8);
+    for (int32_t i = 0; i < m; i += MC) {
+      mc = s_min(m - i, MC);
+      PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, packedA_int8);
+      if (bias == nullptr) {
+        InnerKernel(mc, nc, alpha, packedA_int8, packedB_int8, beta,
+                    packedC_int32, &C(i, j), ldc, relu);
+      } else {
+        InnerKernelWithBias(mc, nc, alpha, packedA_int8, packedB_int8, beta,
+                            packedC_int32, &C(i, j), ldc, relu, bias + i);
+      }
+    }
+  }
+
+  paddle_mobile::memory::Free(packedA_int8);
+  paddle_mobile::memory::Free(packedB_int8);
+  paddle_mobile::memory::Free(packedC_int32);
+  paddle_mobile::memory::Free(zero_int8);
+}
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle_mobile
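With the definition in the header, one body now serves every output type; the inner kernel that actually runs is selected by the Otype deduced at the call site, and the previously impossible combinations are stubbed out as empty specializations in gemm_int8.cpp below. A hedged usage sketch — leading dimensions chosen for dense row-major matrices, and the helper name is illustrative:

    #include <cstdint>
    #include "operators/math/gemm.h"

    void RunInt8Gemm(int32_t m, int32_t n, int32_t k, const int8_t *A,
                     const int8_t *B, float scale, int32_t *c32, int8_t *c8,
                     int32_t *bias) {
      paddle_mobile::operators::math::Gemm gemm;
      // int32 output, no bias, no relu: deduces Otype = int32_t.
      gemm.Sgemm(m, n, k, 1.0f, A, k, B, n, 0.0f, c32, n, false,
                 static_cast<int32_t *>(nullptr));
      // int8 output, int32 bias, fused relu, alpha as the requantization
      // scale: deduces Otype = int8_t.
      gemm.Sgemm(m, n, k, scale, A, k, B, n, 0.0f, c8, n, true, bias);
    }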
diff --git a/src/operators/math/gemm_int8.cpp b/src/operators/math/gemm_int8.cpp
index d0de4d6f09a45694c18c7ce15834967754ee462b..d5788eafe42c391d3f14fea8a381ad28d4010429 100644
--- a/src/operators/math/gemm_int8.cpp
+++ b/src/operators/math/gemm_int8.cpp
@@ -14,7 +14,6 @@ limitations under the License. */
 
 #include <string.h>
 #include "common/log.h"
-#include "memory/t_malloc.h"
 #include "operators/math/gemm.h"
 #if __ARM_NEON
 #include <arm_neon.h>
@@ -670,6 +669,11 @@ void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
 }
 
 // 8 bits int inner product
+template <>
+void Gemm::InnerKernel(int32_t mc, int32_t nc, float alpha, const int8_t *a,
+                       const int8_t *b, float beta, int32_t *c, int8_t *C,
+                       int32_t ldc, bool relu) {}
+template <>
 void Gemm::InnerKernel(int32_t mc, int32_t nc, float alpha, const int8_t *a,
                        const int8_t *b, float beta, int32_t *c, int32_t *C,
                        int32_t ldc, bool relu) {
@@ -691,6 +695,7 @@ void Gemm::InnerKernel(int32_t mc, int32_t nc, float alpha, const int8_t *a,
   }
 }
 
+template <>
 void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, float alpha,
                                const int8_t *a, const int8_t *b, float beta,
                                int32_t *c, int8_t *C, int32_t ldc, bool relu,
@@ -715,6 +720,12 @@ void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, float alpha,
   }
 }
 
+template <>
+void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, float alpha,
+                               const int8_t *a, const int8_t *b, float beta,
+                               int32_t *c, int32_t *C, int32_t ldc, bool relu,
+                               int32_t *bias) {}
+
 // 8 bits int PackMatrixA_4r
 void Gemm::PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail,
                              const int8_t *A, int32_t lda, int8_t *buffer) {
@@ -1083,128 +1094,6 @@ void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
   }
 }
 
-// 8 bits int matrix product (m*k x k*n)
-void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
-                 int32_t lda, const int8_t *B, int32_t ldb, float beta,
-                 int32_t *C, int32_t ldc, bool relu, int32_t *bias) {
-  // L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73)
-  // L2 cache is 0.5~4 Mib (Contex-A72 cluster)
-  int32_t L1 = 32 * 1024;
-  int32_t L2 = 512 * 1024;
-
-  const int32_t k_complete = (k + 15) - ((k + 15) & 15);
-  KC = k_complete;
-  MC = L1 / (KC * sizeof(int8_t));
-  NC = L2 / (KC * sizeof(int8_t));
-
-  // make sure MC is multiple of MR_INT8, and NC is multiple of NR_INT8
-  if (MC == 0) {
-    MC = MR_INT8;
-  } else {
-    int32_t mblock_num = (m + MC - 1) / MC;
-    MC = (m + mblock_num - 1) / mblock_num;
-    MC = (MC + MR_INT8 - 1) / MR_INT8 * MR_INT8;
-  }
-  // DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n";
-  if (NC == 0) {
-    NC = NR_INT8;
-  } else {
-    int32_t nblock_num = (n + NC - 1) / NC;
-    NC = (n + nblock_num - 1) / nblock_num;
-    NC = (NC + NR_INT8 - 1) / NR_INT8 * NR_INT8;
-  }
-  // DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n";
-  packedA_int8 = static_cast<int8_t *>(
-      paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC));
-  packedB_int8 = static_cast<int8_t *>(
-      paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC));
-  packedC_int32 = static_cast<int32_t *>(
-      paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC));
-  zero_int8 =
-      static_cast<int8_t *>(paddle_mobile::memory::Alloc(sizeof(int8_t) * k));
-
-  memset(static_cast<void *>(zero_int8), 0, sizeof(int8_t) * k);
-  int32_t mc, nc;
-  for (int32_t j = 0; j < n; j += NC) {
-    nc = s_min(n - j, NC);
-    PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, packedB_int8);
-    for (int32_t i = 0; i < m; i += MC) {
-      mc = s_min(m - i, MC);
-      PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, packedA_int8);
-      if (bias == nullptr) {
-        InnerKernel(mc, nc, alpha, packedA_int8, packedB_int8, beta,
-                    packedC_int32, &C(i, j), ldc, relu);
-      }
-    }
-  }
-
-  paddle_mobile::memory::Free(packedA_int8);
-  paddle_mobile::memory::Free(packedB_int8);
-  paddle_mobile::memory::Free(packedC_int32);
-  paddle_mobile::memory::Free(zero_int8);
-}
-
-// 8 bits int matrix product (m*k x k*n)
-void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
-                 int32_t lda, const int8_t *B, int32_t ldb, float beta,
-                 int8_t *C, int32_t ldc, bool relu, int32_t *bias) {
-  // L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73)
-  // L2 cache is 0.5~4 Mib (Contex-A72 cluster)
-  int32_t L1 = 32 * 1024;
-  int32_t L2 = 512 * 1024;
-
-  const int32_t k_complete = (k + 15) - ((k + 15) & 15);
-  KC = k_complete;
-  MC = L1 / (KC * sizeof(int8_t));
-  NC = L2 / (KC * sizeof(int8_t));
-
-  // make sure MC is multiple of MR_INT8, and NC is multiple of NR_INT8
-  if (MC == 0) {
-    MC = MR_INT8;
-  } else {
-    int32_t mblock_num = (m + MC - 1) / MC;
-    MC = (m + mblock_num - 1) / mblock_num;
-    MC = (MC + MR_INT8 - 1) / MR_INT8 * MR_INT8;
-  }
-  // DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n";
-  if (NC == 0) {
-    NC = NR_INT8;
-  } else {
-    int32_t nblock_num = (n + NC - 1) / NC;
-    NC = (n + nblock_num - 1) / nblock_num;
-    NC = (NC + NR_INT8 - 1) / NR_INT8 * NR_INT8;
-  }
-  // DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n";
-  packedA_int8 = static_cast<int8_t *>(
-      paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC));
-  packedB_int8 = static_cast<int8_t *>(
-      paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC));
-  packedC_int32 = static_cast<int32_t *>(
-      paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC));
-  zero_int8 =
-      static_cast<int8_t *>(paddle_mobile::memory::Alloc(sizeof(int8_t) * k));
-
-  memset(static_cast<void *>(zero_int8), 0, sizeof(int8_t) * k);
-  int32_t mc, nc;
-  for (int32_t j = 0; j < n; j += NC) {
-    nc = s_min(n - j, NC);
-    PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, packedB_int8);
-    for (int32_t i = 0; i < m; i += MC) {
-      mc = s_min(m - i, MC);
-      PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, packedA_int8);
-      if (bias != nullptr) {
-        InnerKernelWithBias(mc, nc, alpha, packedA_int8, packedB_int8, beta,
-                            packedC_int32, &C(i, j), ldc, relu, bias + i);
-      }
-    }
-  }
-
-  paddle_mobile::memory::Free(packedA_int8);
-  paddle_mobile::memory::Free(packedB_int8);
-  paddle_mobile::memory::Free(packedC_int32);
-  paddle_mobile::memory::Free(zero_int8);
-}
-
 // 8 bits int write back
 // C = A * B
 void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
diff --git a/src/operators/op_param.h b/src/operators/op_param.h
index 12c26aed3ac8685a4f8b662e3bb39ff711a7019a..289e29c382ad39006fb65b38be3bc9ebfc58fed6 100644
--- a/src/operators/op_param.h
+++ b/src/operators/op_param.h
@@ -1705,36 +1705,19 @@ class FusionConvAddReluParam : public FusionConvAddParam<Dtype> {
   FusionConvAddReluParam(const VariableNameMap &inputs,
                          const VariableNameMap &outputs,
                          const AttributeMap &attrs, const Scope &scope)
-      : FusionConvAddParam<Dtype>(inputs, outputs, attrs, scope) {}
-};
-#endif
-
+      : FusionConvAddParam<Dtype>(inputs, outputs, attrs, scope) {
 #ifdef FUSION_CONVADDRELU_INT8_OP
-template <typename Dtype>
-class FusionConvAddReluInt8Param : public ConvParam<Dtype> {
-  typedef typename DtypeTensorTrait<Dtype>::gtype GType;
-  typedef typename DtypeTensorTrait<Dtype>::rtype RType;
-
- public:
-  FusionConvAddReluInt8Param(const VariableNameMap &inputs,
-                             const VariableNameMap &outputs,
-                             const AttributeMap &attrs, const Scope &scope)
-      : ConvParam<Dtype>(inputs, outputs, attrs, scope) {
     scale_ = OpParam::InputScaleFrom(inputs, scope);
-    bias_ = OpParam::InputYFrom(inputs, scope);
-    axis_ = OpParam::GetAttr<int>("axis", attrs);
+#endif
   }
-
+#ifdef FUSION_CONVADDRELU_INT8_OP
+  typedef typename DtypeTensorTrait<Dtype>::gtype GType;
+  typedef typename DtypeTensorTrait<Dtype>::rtype RType;
   const RType *InputScale() const { return scale_; }
-  RType *Bias() const { return bias_; }
-
-  const int &Axis() const { return axis_; }
-
  protected:
   RType *scale_;
-  RType *bias_;
-  int axis_;
+#endif
 };
 #endif
diff --git a/test/common/test_gemm_int8_accuracy.cpp b/test/common/test_gemm_int8_accuracy.cpp
index 9120a9c7fac64e6ee1a1ee89691848496b936ac2..6e2d838955f33c8a7784b2df4c6f32baa3665e67 100644
--- a/test/common/test_gemm_int8_accuracy.cpp
+++ b/test/common/test_gemm_int8_accuracy.cpp
@@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include <climits>
 #include <cstdlib>
 #include <ctime>
 #include <iostream>
+#include <limits>
 #include <random>
 #include "../test_helper.h"
 #include "common/log.h"
@@ -57,10 +57,10 @@ void print_matirx(int m, int n, int ldc, int8_t *c) {
 
 int32_t qadd_int32(int32_t l, int32_t r) {
   int64_t res = static_cast<int64_t>(l) + static_cast<int64_t>(r);
-  if (res > INT_MAX)
-    return INT_MAX;
-  else if (res < INT_MIN)
-    return INT_MIN;
+  if (res > std::numeric_limits<int32_t>::max())
+    return std::numeric_limits<int32_t>::max();
+  else if (res < std::numeric_limits<int32_t>::min())
+    return std::numeric_limits<int32_t>::min();
   else
     return static_cast<int32_t>(res);
 }
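Both this test and the conv op test below swap INT_MAX/INT_MIN for std::numeric_limits. The behavior is unchanged, but the trait form needs no macro header and generalizes over the clamped type. A generic variant of the tests' saturating add, valid for integer types no wider than 32 bits since the intermediate sum is 64-bit:

    #include <cstdint>
    #include <limits>

    template <typename T>
    T SaturatingAdd(T l, T r) {
      int64_t res = static_cast<int64_t>(l) + static_cast<int64_t>(r);
      if (res > std::numeric_limits<T>::max())
        return std::numeric_limits<T>::max();
      if (res < std::numeric_limits<T>::min())
        return std::numeric_limits<T>::min();
      return static_cast<T>(res);
    }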
diff --git a/test/operators/test_fusion_conv_add_relu_int8_op.cpp b/test/operators/test_fusion_conv_add_relu_int8_op.cpp
index 4c80f9c449b61c8a9edb397c6149a4b36ac7a486..8d7067898e98982a506c7ed707800d07aa57884e 100644
--- a/test/operators/test_fusion_conv_add_relu_int8_op.cpp
+++ b/test/operators/test_fusion_conv_add_relu_int8_op.cpp
@@ -12,6 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#ifdef FUSION_CONVADDRELU_INT8_OP
+
+#include <iostream>
+#include <limits>
 #include "../test_helper.h"
 #include "../test_include.h"
 #include "operators/fusion_conv_add_relu_int8_op.h"
@@ -19,10 +23,10 @@ limitations under the License. */
 namespace paddle_mobile {
 int32_t qadd_int32(int32_t l, int32_t r) {
   int64_t res = static_cast<int64_t>(l) + static_cast<int64_t>(r);
-  if (res > INT_MAX)
-    return INT_MAX;
-  else if (res < INT_MIN)
-    return INT_MIN;
+  if (res > std::numeric_limits<int32_t>::max())
+    return std::numeric_limits<int32_t>::max();
+  else if (res < std::numeric_limits<int32_t>::min())
+    return std::numeric_limits<int32_t>::min();
   else
     return static_cast<int32_t>(res);
 }
@@ -217,8 +221,8 @@ int TestConvOp(int in_channels, int in_height, int in_width, int out_channels) {
   inputs["Input"] = std::vector<std::string>({"input"});
   inputs["Filter"] = std::vector<std::string>({"filter"});
   inputs["Scale"] = std::vector<std::string>({"scale"});
-  inputs["Y"] = std::vector<std::string>({"y"});
-  outputs["Output"] = std::vector<std::string>({"output"});
+  inputs["Y"] = std::vector<std::string>({"bias"});
+  outputs["Out"] = std::vector<std::string>({"output"});
 
   auto input_var = scope.get()->Var("input");
   auto input = input_var->template GetMutable<framework::LoDTensor>();
@@ -234,7 +238,7 @@ int TestConvOp(int in_channels, int in_height, int in_width, int out_channels) {
   float scale_v = 0.000828f;
   scale->mutable_data<float>()[0] = scale_v;
 
-  auto bias_var = scope.get()->Var("y");
+  auto bias_var = scope.get()->Var("bias");
   auto bias = bias_var->template GetMutable<framework::LoDTensor>();
   SetupTensor<int32_t>(bias, bias_shape, -127, 127);
 
@@ -352,3 +356,5 @@ int main(int argc, char *argv[]) {
   paddle_mobile::TestConvOp(in_channels, in_height, in_width,
                             out_channels);
 }
+
+#endif
diff --git a/test/operators/test_mul_op.cpp b/test/operators/test_mul_op.cpp
index 83da418025a3b59de8c8e1a5edd59c8b3ed49e90..2734bbeace0271fee67467c8943340ace2767d86 100644
--- a/test/operators/test_mul_op.cpp
+++ b/test/operators/test_mul_op.cpp
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include <iostream>
 #include "../test_helper.h"
 #include "../test_include.h"
 #include "operators/mul_op.h"