From a80b04b9f7f2d04f6adb3a6979c8ae18bcc53f4d Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Wed, 28 Nov 2018 21:55:20 +0800 Subject: [PATCH] add fusion_conv_add_relu_int8_op and unit test. --- src/common/types.cpp | 2 + src/common/types.h | 1 + src/framework/op_registry.h | 18 + .../fusion_conv_add_relu_int8_op.cpp | 56 +++ src/operators/fusion_conv_add_relu_int8_op.h | 44 +++ .../kernel/arm/conv_add_relu_int8_kernel.cpp | 39 ++ .../central-arm-func/conv_add_relu_arm_func.h | 1 + .../conv_add_relu_int8_arm_func.h | 125 +++++++ .../kernel/conv_add_relu_int8_kernel.h | 45 +++ src/operators/math/gemm_int8.cpp | 3 + src/operators/op_param.h | 31 +- test/CMakeLists.txt | 4 + test/common/test_gemm_int8_accuracy.cpp | 17 +- test/common/test_gemm_perf.cpp | 4 +- .../test_fusion_conv_add_relu_int8_op.cpp | 354 ++++++++++++++++++ test/operators/test_mul_op.cpp | 6 +- tools/op.cmake | 4 + 17 files changed, 743 insertions(+), 11 deletions(-) create mode 100644 src/operators/fusion_conv_add_relu_int8_op.cpp create mode 100644 src/operators/fusion_conv_add_relu_int8_op.h create mode 100644 src/operators/kernel/arm/conv_add_relu_int8_kernel.cpp create mode 100644 src/operators/kernel/central-arm-func/conv_add_relu_int8_arm_func.h create mode 100644 src/operators/kernel/conv_add_relu_int8_kernel.h create mode 100644 test/operators/test_fusion_conv_add_relu_int8_op.cpp diff --git a/src/common/types.cpp b/src/common/types.cpp index ba00f639d7..6cea95546d 100644 --- a/src/common/types.cpp +++ b/src/common/types.cpp @@ -24,6 +24,7 @@ const char *G_OP_TYPE_CONCAT = "concat"; const char *G_OP_TYPE_ELEMENTWISE_ADD = "elementwise_add"; const char *G_OP_TYPE_FILL_CONSTANT = "fill_constant"; const char *G_OP_TYPE_FUSION_CONV_ADD_RELU = "fusion_conv_add_relu"; +const char *G_OP_TYPE_FUSION_CONV_ADD_RELU_INT8 = "fusion_conv_add_relu_int8"; const char *G_OP_TYPE_FUSION_CONV_ADD_PRELU = "fusion_conv_add_prelu"; const char *G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU = "fusion_conv_add_add_prelu"; const char *G_OP_TYPE_FUSION_CONV_ADD_BN_RELU = "fusion_conv_add_bn_relu"; @@ -111,6 +112,7 @@ std::unordered_map< {G_OP_TYPE_DEPTHWISE_CONV, {{"Input"}, {"Output"}}}, {G_OP_TYPE_FILL_CONSTANT, {{}, {"Out"}}}, {G_OP_TYPE_FUSION_CONV_ADD_RELU, {{"Input"}, {"Out"}}}, + {G_OP_TYPE_FUSION_CONV_ADD_RELU_INT8, {{"Input"}, {"Output"}}}, {G_OP_TYPE_FUSION_CONV_ADD_PRELU, {{"Input"}, {"Out"}}}, {G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU, {{"Input"}, {"Out"}}}, {G_OP_TYPE_IM2SEQUENCE, {{"X"}, {"Out"}}}, diff --git a/src/common/types.h b/src/common/types.h index e9c0f81232..a1a9185733 100644 --- a/src/common/types.h +++ b/src/common/types.h @@ -99,6 +99,7 @@ extern const char *G_OP_TYPE_BOX_CODER; extern const char *G_OP_TYPE_CONCAT; extern const char *G_OP_TYPE_ELEMENTWISE_ADD; extern const char *G_OP_TYPE_FUSION_CONV_ADD_RELU; +extern const char *G_OP_TYPE_FUSION_CONV_ADD_RELU_INT8; extern const char *G_OP_TYPE_FUSION_CONV_ADD_PRELU; extern const char *G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU; extern const char *G_OP_TYPE_FC; diff --git a/src/framework/op_registry.h b/src/framework/op_registry.h index 219385ab14..52cae493ea 100644 --- a/src/framework/op_registry.h +++ b/src/framework/op_registry.h @@ -98,6 +98,24 @@ class OpRegistry { } }; +#define REGISTER_OPERATOR_INT8(op_type, op_class, device_name, device_type) \ + template class op_class; \ + template \ + class _OpClass_##op_type##_##device_name : public op_class { \ + public: \ + DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_##device_name, op_class); \ + }; \ + static paddle_mobile::framework::OperatorRegistrar< \ + device_type, _OpClass_##op_type##_##device_name> \ + __op_registrar_##op_type##_##device_name(#op_type); \ + int TouchOpRegistrar_##op_type##_##device_name() { \ + __op_registrar_##op_type##_##device_name.Touch(); \ + return 0; \ + } + +#define REGISTER_OPERATOR_CPU_INT8(op_type, op_class) \ + REGISTER_OPERATOR_INT8(op_type, op_class, cpu, paddle_mobile::CPU); + #define REGISTER_OPERATOR(op_type, op_class, device_name, device_type) \ template class op_class; \ template \ diff --git a/src/operators/fusion_conv_add_relu_int8_op.cpp b/src/operators/fusion_conv_add_relu_int8_op.cpp new file mode 100644 index 0000000000..ac0226ec7a --- /dev/null +++ b/src/operators/fusion_conv_add_relu_int8_op.cpp @@ -0,0 +1,56 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef FUSION_CONVADDRELU_INT8_OP + +#include "operators/fusion_conv_add_relu_int8_op.h" +#include +#include "operators/math/conv_func.h" + +namespace paddle_mobile { +namespace operators { + +template +void FusionConvAddReluInt8Op::InferShape() const { + auto in_dims = this->param_.Input()->dims(); + auto filter_dims = this->param_.Filter()->dims(); + const std::vector &strides = this->param_.Strides(); + std::vector paddings = this->param_.Paddings(); + int groups = this->param_.Groups(); + std::vector dilations = this->param_.Dilations(); + + PADDLE_MOBILE_ENFORCE((in_dims.size() == filter_dims.size() && + dilations.size() == paddings.size() && + paddings.size() == strides.size()), + "ConvParam is not suitable"); + + std::vector output_shape({in_dims[0], filter_dims[0]}); + for (size_t i = 0; i < strides.size(); ++i) { + output_shape.push_back( + math::ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i], + paddings[i], strides[i])); + } + framework::DDim ddim = framework::make_ddim(output_shape); + this->param_.Output()->Resize(ddim); +} + +} // namespace operators +} // namespace paddle_mobile + +namespace ops = paddle_mobile::operators; +#ifdef PADDLE_MOBILE_CPU +REGISTER_OPERATOR_CPU_INT8(fusion_conv_add_relu_int8, + ops::FusionConvAddReluInt8Op); +#endif +#endif // FUSION_CONVADDRELU_INT8_OP diff --git a/src/operators/fusion_conv_add_relu_int8_op.h b/src/operators/fusion_conv_add_relu_int8_op.h new file mode 100644 index 0000000000..c9ca511eaa --- /dev/null +++ b/src/operators/fusion_conv_add_relu_int8_op.h @@ -0,0 +1,44 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef FUSION_CONVADDRELU_INT8_OP +#pragma once +#include +#include "framework/operator.h" +#include "operators/kernel/conv_add_relu_int8_kernel.h" +#include "operators/op_param.h" +namespace paddle_mobile { +namespace operators { +using std::string; +template +class FusionConvAddReluInt8Op + : public framework::OperatorWithKernel< + DeviceType, FusionConvAddReluInt8Param, + operators::ConvAddReluInt8Kernel> { + public: + FusionConvAddReluInt8Op(const string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, + const framework::AttributeMap &attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel< + DeviceType, FusionConvAddReluInt8Param, + operators::ConvAddReluInt8Kernel>( + type, inputs, outputs, attrs, scope) {} + void InferShape() const override; + + protected: +}; +} // namespace operators +} // namespace paddle_mobile +#endif // FUSION_CONVADDRELU_INT8_OP diff --git a/src/operators/kernel/arm/conv_add_relu_int8_kernel.cpp b/src/operators/kernel/arm/conv_add_relu_int8_kernel.cpp new file mode 100644 index 0000000000..b73dcf0c02 --- /dev/null +++ b/src/operators/kernel/arm/conv_add_relu_int8_kernel.cpp @@ -0,0 +1,39 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef FUSION_CONVADDRELU_INT8_OP + +#include "operators/kernel/conv_add_relu_int8_kernel.h" +#include "operators/kernel/central-arm-func/conv_add_relu_int8_arm_func.h" + +namespace paddle_mobile { +namespace operators { + +template <> +bool ConvAddReluInt8Kernel::Init( + FusionConvAddReluInt8Param *param) { + return true; +} + +template <> +void ConvAddReluInt8Kernel::Compute( + const FusionConvAddReluInt8Param ¶m) { + ConvAddReluInt8Compute(param); +} +template class ConvAddReluInt8Kernel; + +} // namespace operators +} // namespace paddle_mobile + +#endif // FUSION_CONVADDRELU_INT8_OP diff --git a/src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h index 36886b9e2c..9ea8dbf0c1 100644 --- a/src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h @@ -33,6 +33,7 @@ void ConvAddReluCompute(const FusionConvAddReluParam ¶m) { int axis = param.Axis(); Tensor *output = param.Output(); float *biase_data = bias.data(); + output->mutable_data

(); int groups = param.Groups(); std::vector strides = param.Strides(); diff --git a/src/operators/kernel/central-arm-func/conv_add_relu_int8_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_relu_int8_arm_func.h new file mode 100644 index 0000000000..fb431dc279 --- /dev/null +++ b/src/operators/kernel/central-arm-func/conv_add_relu_int8_arm_func.h @@ -0,0 +1,125 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef FUSION_CONVADDRELU_INT8_OP + +#pragma once +#include +#include "operators/math/conv_func.h" +#include "operators/math/im2col.h" +#include "operators/math/math_function.h" +#include "operators/math/vol2col.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +template +void ConvAddReluInt8Compute(const FusionConvAddReluInt8Param ¶m) { + const Tensor *input = param.Input(); + Tensor filter = *param.Filter(); + Tensor bias = *param.Bias(); + Tensor scale = *param.InputScale(); + int32_t axis = param.Axis(); + Tensor *output = param.Output(); + output->mutable_data

(); + + int32_t *biase_data = bias.data(); + float scale_v = scale.data()[0]; + + int32_t groups = param.Groups(); + std::vector strides = param.Strides(); + std::vector paddings = param.Paddings(); + std::vector dilations = param.Dilations(); + + const int32_t batch_size = static_cast(input->dims()[0]); + + std::vector filter_shape_vec(framework::vectorize(filter.dims())); + + std::vector output_shape_vec(framework::vectorize(output->dims())); + size_t data_dim = filter_shape_vec.size() - 2; + std::vector col_shape_vec(1 + 2 * data_dim); + col_shape_vec[0] = input->dims()[1] / groups; + for (size_t j = 0; j < data_dim; ++j) { + col_shape_vec[j + 1] = filter_shape_vec[j + 2]; + col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2]; + } + framework::DDim col_shape(framework::make_ddim(col_shape_vec)); + + framework::DDim col_matrix_shape = + framework::flatten_to_2d(col_shape, data_dim + 1); + + bool is_expand = + math::IsExpand(filter_shape_vec, strides, paddings, dilations); + Tensor col; + Tensor col_matrix; + if (is_expand) { + col.mutable_data

(col_shape); + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + } + + framework::DDim input_shape = framework::slice_ddim( + input->dims(), 1, static_cast(input->dims().size())); + + framework::DDim filter_matrix_shape = {filter.dims()[0], + filter.numel() / filter.dims()[0]}; + filter.Resize(filter_matrix_shape); + framework::DDim output_matrix_shape = { + output->dims()[1], + output->numel() / (output->dims()[0] * output->dims()[1])}; + + // convolution operator: im2col(or vol2col) + gemm + int32_t in_step = static_cast(input->dims()[1]) / groups; + int32_t out_step = static_cast(output->dims()[1]) / groups; + + math::Vol2ColFunctor vol2col; + math::Im2ColFunctor im2col; + + for (int32_t i = 0; i < batch_size; i++) { + Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); + + for (int32_t g = 0; g < groups; g++) { + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + + if (!is_expand) { + col.ShareDataWith(in_slice); + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + } else if (data_dim == 2U) { + // im2col + im2col(in_slice, dilations, strides, + std::vector{paddings[0], paddings[1], paddings[0], + paddings[1]}, + &col); + } else if (data_dim == 3U) { + // vol2col + vol2col(in_slice, dilations, strides, paddings, &col); + } + + // gemm + Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); + + math::matmul_int8(filter_slice, false, col_matrix, false, scale_v, + &out_slice, static_cast(0), true, biase_data); + } + } +} + +} // namespace operators +} // namespace paddle_mobile + +#endif // FUSION_CONVADDRELU_INT8_OP diff --git a/src/operators/kernel/conv_add_relu_int8_kernel.h b/src/operators/kernel/conv_add_relu_int8_kernel.h new file mode 100644 index 0000000000..ecd9f3d863 --- /dev/null +++ b/src/operators/kernel/conv_add_relu_int8_kernel.h @@ -0,0 +1,45 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef FUSION_CONVADDRELU_INT8_OP + +#pragma once + +#include +#include "framework/ddim.h" +#include "framework/operator.h" +#include "operators/math/conv_func.h" +#include "operators/math/im2col.h" +#include "operators/math/math_function.h" +#include "operators/math/vol2col.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +using framework::DDim; +using framework::OpKernelBase; + +template +class ConvAddReluInt8Kernel + : public OpKernelBase> { + public: + void Compute(const FusionConvAddReluInt8Param ¶m); + bool Init(FusionConvAddReluInt8Param *param); +}; + +} // namespace operators +} // namespace paddle_mobile + +#endif // FUSION_CONVADDRELU_INT8_OP diff --git a/src/operators/math/gemm_int8.cpp b/src/operators/math/gemm_int8.cpp index 555672720f..d0de4d6f09 100644 --- a/src/operators/math/gemm_int8.cpp +++ b/src/operators/math/gemm_int8.cpp @@ -243,6 +243,9 @@ void Gemm::AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, #endif // __ARM_NEON } +// The core idea of AddDot4x2 function is borrowed from the Google's gemmlowp +// open source library. The address of gemmlowp is +// https://github.com/google/gemmlowp. void Gemm::AddDot4x2(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, int32_t ldc) { #if __ARM_NEON diff --git a/src/operators/op_param.h b/src/operators/op_param.h index b6597b55a9..ea79a3af2d 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -437,7 +437,7 @@ class ConvParam : public OpParam { #endif - private: + protected: RType *input_; mutable RType *output_; mutable RType *filter_; @@ -1709,6 +1709,35 @@ class FusionConvAddReluParam : public FusionConvAddParam { }; #endif +#ifdef FUSION_CONVADDRELU_INT8_OP +template +class FusionConvAddReluInt8Param : public ConvParam { + typedef typename DtypeTensorTrait::gtype GType; + typedef typename DtypeTensorTrait::rtype RType; + + public: + FusionConvAddReluInt8Param(const VariableNameMap &inputs, + const VariableNameMap &outputs, + const AttributeMap &attrs, const Scope &scope) + : ConvParam(inputs, outputs, attrs, scope) { + scale_ = OpParam::InputScaleFrom(inputs, scope); + bias_ = OpParam::InputYFrom(inputs, scope); + axis_ = OpParam::GetAttr("axis", attrs); + } + + const RType *InputScale() const { return scale_; } + + RType *Bias() const { return bias_; } + + const int &Axis() const { return axis_; } + + protected: + RType *scale_; + RType *bias_; + int axis_; +}; +#endif + #ifdef FUSION_CONVADDPRELU_OP template class FusionConvAddPReluParam : public ConvParam { diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index bfd125ce5b..0f489995cb 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -324,6 +324,10 @@ if (NOT FOUND_MATCH) ADD_EXECUTABLE(test-conv-add-relu-op operators/test_conv_add_relu_op.cpp test_helper.h test_include.h executor_for_test.h) target_link_libraries(test-conv-add-relu-op paddle-mobile) + # gen test + ADD_EXECUTABLE(test-conv-add-relu-int8-op operators/test_fusion_conv_add_relu_int8_op.cpp test_helper.h test_include.h) + target_link_libraries(test-conv-add-relu-int8-op paddle-mobile) + # gen test ADD_EXECUTABLE(test-conv-add-bn-relu-op operators/test_fusion_conv_add_bn_relu_op.cpp test_helper.h test_include.h executor_for_test.h) target_link_libraries(test-conv-add-bn-relu-op paddle-mobile) diff --git a/test/common/test_gemm_int8_accuracy.cpp b/test/common/test_gemm_int8_accuracy.cpp index f276cad8e6..9120a9c7fa 100644 --- a/test/common/test_gemm_int8_accuracy.cpp +++ b/test/common/test_gemm_int8_accuracy.cpp @@ -65,12 +65,19 @@ int32_t qadd_int32(int32_t l, int32_t r) { return static_cast(res); } +// round to zero +float round2zero(float v) { + float res; + if (v > 0) + res = std::floor(v); + else if (v < 0) + res = std::ceil(v); + return res; +} + int8_t qscale_int32(int32_t v, float scale) { float res = static_cast(v) * scale; - if (res > 0) - res = std::floor(res); - else if (res < 0) - res = std::ceil(res); // round to zero + res = round2zero(res); if (res > 127) return static_cast(127); else if (res < -127) @@ -155,7 +162,7 @@ int do_sgemm_with_bias(int m, int n, int k, bool relu, int pr) { int lda = k; int ldb = n; int ldc = n; - float scale = 0.00628; + float scale = 0.00628f; default_random_engine e; uniform_int_distribution pixel(-127, 127); int8_t *a = static_cast( diff --git a/test/common/test_gemm_perf.cpp b/test/common/test_gemm_perf.cpp index 5ca0b40cfc..5c5f4026fd 100644 --- a/test/common/test_gemm_perf.cpp +++ b/test/common/test_gemm_perf.cpp @@ -103,13 +103,13 @@ int main() { // warm-up 10 times for (int j = 0; j < 10; ++j) { paddle_mobile::operators::math::matmul_int8( - aa_int8, false, bb_int8, false, static_cast(1), &cc_int8, + aa_int8, false, bb_int8, false, static_cast(0.618), &cc_int8, static_cast(0), true, &bias_data[0]); } auto time5 = time(); for (int j = 0; j < 10; ++j) { paddle_mobile::operators::math::matmul_int8( - aa_int8, false, bb_int8, false, static_cast(1), &cc_int8, + aa_int8, false, bb_int8, false, static_cast(0.618), &cc_int8, static_cast(0), true, &bias_data[0]); } auto time6 = time(); diff --git a/test/operators/test_fusion_conv_add_relu_int8_op.cpp b/test/operators/test_fusion_conv_add_relu_int8_op.cpp new file mode 100644 index 0000000000..4c80f9c449 --- /dev/null +++ b/test/operators/test_fusion_conv_add_relu_int8_op.cpp @@ -0,0 +1,354 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "../test_helper.h" +#include "../test_include.h" +#include "operators/fusion_conv_add_relu_int8_op.h" + +namespace paddle_mobile { +int32_t qadd_int32(int32_t l, int32_t r) { + int64_t res = static_cast(l) + static_cast(r); + if (res > INT_MAX) + return INT_MAX; + else if (res < INT_MIN) + return INT_MIN; + else + return static_cast(res); +} + +// round to zero +float round2zero(float v) { + float res; + if (v > 0) + res = std::floor(v); + else if (v < 0) + res = std::ceil(v); + return res; +} + +int8_t qscale_int32(int32_t v, float scale) { + float res = static_cast(v) * scale; + res = round2zero(res); + if (res > 127) + return static_cast(127); + else if (res < -127) + return static_cast(-127); + else + return static_cast(res); +} + +// Reference convolution from Caffe for checking results. +// accumulate through explicit loops over input, output, and filters. +template +void conv2d(const framework::Tensor *input, const framework::Tensor *filter, + const framework::Tensor *bias, const framework::AttributeMap &attrs, + framework::Tensor *output, float scale) { + framework::AttrReader attr_reader(attrs); + std::vector paddings = attr_reader.Get>("paddings"); + std::vector strides = attr_reader.Get>("strides"); + std::vector dilations = attr_reader.Get>("dilations"); + int groups = attr_reader.Get("groups"); + int kernel_h = filter->dims()[2]; + int kernel_w = filter->dims()[3]; + int pad_h = paddings[0]; + int pad_w = paddings[1]; + int stride_h = strides[0]; + int stride_w = strides[1]; + int dilation_h = dilations[0]; + int dilation_w = dilations[1]; + auto in_shape = input->dims(); + auto out_shape = output->dims(); + + const bool has_depth = 0; + int kernel_d, pad_d, stride_d, dilation_d; + if (has_depth) { + kernel_d = kernel_h; + stride_d = stride_h; + pad_d = pad_h; + dilation_d = dilation_h; + } else { + kernel_d = stride_d = dilation_d = 1; + pad_d = 0; + } + // Groups + int o_g = out_shape[1] / groups; + int k_g = in_shape[1] / groups; + int o_head, k_head; + // Convolution + vector weight_offset(4 + has_depth); + vector in_offset(4 + has_depth); + vector out_offset(4 + has_depth); + auto offset = [](const framework::Tensor *input, const vector &indics) { + framework::DDim shape = input->dims(); + size_t count = 0; + for (int i = 0; i < indics.size(); ++i) { + count *= shape[i]; + count += indics[i]; + } + return count; + }; + + const T *in_data = input->data(); + const T *w_data = filter->data(); + framework::Tensor output_32; + int32_t *out_data_32 = output_32.mutable_data(out_shape); + memset(out_data_32, 0, output_32.numel() * sizeof(int32_t)); + for (int n = 0; n < out_shape[0]; n++) { + for (int g = 0; g < groups; g++) { + o_head = o_g * g; + k_head = k_g * g; + for (int o = 0; o < o_g; o++) { + for (int k = 0; k < k_g; k++) { + for (int z = 0; z < (has_depth ? out_shape[2] : 1); z++) { + for (int y = 0; y < out_shape[2 + has_depth]; y++) { + for (int x = 0; x < out_shape[3 + has_depth]; x++) { + for (int r = 0; r < kernel_d; r++) { + for (int p = 0; p < kernel_h; p++) { + for (int q = 0; q < kernel_w; q++) { + int in_z = z * stride_d - pad_d + r * dilation_d; + int in_y = y * stride_h - pad_h + p * dilation_h; + int in_x = x * stride_w - pad_w + q * dilation_w; + if (in_z >= 0 && in_z < (has_depth ? in_shape[2] : 1) && + in_y >= 0 && in_y < in_shape[2 + has_depth] && + in_x >= 0 && in_x < in_shape[3 + has_depth]) { + weight_offset[0] = o + o_head; + weight_offset[1] = k; + if (has_depth) { + weight_offset[2] = r; + } + weight_offset[2 + has_depth] = p; + weight_offset[3 + has_depth] = q; + in_offset[0] = n; + in_offset[1] = k + k_head; + if (has_depth) { + in_offset[2] = in_z; + } + in_offset[2 + has_depth] = in_y; + in_offset[3 + has_depth] = in_x; + out_offset[0] = n; + out_offset[1] = o + o_head; + if (has_depth) { + out_offset[2] = z; + } + out_offset[2 + has_depth] = y; + out_offset[3 + has_depth] = x; + + out_data_32[offset(output, out_offset)] += + in_data[offset(input, in_offset)] * + w_data[offset(filter, weight_offset)]; + } + } + } + } + } + } + } + } + } + } + } + + T *out_data = output->mutable_data(); + int32_t n = out_shape[0]; + int32_t c = out_shape[1]; + int32_t h = out_shape[2]; + int32_t w = out_shape[3]; + const int32_t *bias_data = bias->data(); + for (int i = 0; i < n; ++i) { + for (int j = 0; j < c; ++j) { + int32_t bias_v = bias_data[j]; + for (int k = 0; k < h; ++k) { + for (int l = 0; l < w; ++l) { + int32_t tmp = out_data_32[i * c * h * w + j * h * w + k * w + l]; + tmp = qadd_int32(tmp, bias_v); + tmp = std::max(0, tmp); + out_data[i * c * h * w + j * h * w + k * w + l] = + qscale_int32(tmp, scale); + } + } + } + } +} + +template +int TestConvOp(int in_channels, int in_height, int in_width, int out_channels) { + int kernel_h = Kernel; + int kernel_w = Kernel; + int pad_h = Pad; + int pad_w = Pad; + int stride_h = Stride; + int stride_w = Stride; + int dilation_h = 1; + int dilation_w = 1; + + int batch_size = 1; + int input_c = in_channels; + int input_h = in_height; + int input_w = in_width; + int output_c = out_channels; + framework::DDim input_shape = + framework::make_ddim({batch_size, input_c, input_h, input_w}); + framework::DDim filter_shape = + framework::make_ddim({output_c, input_c, kernel_h, kernel_w}); + + int kernel_extent_h = dilation_h * (kernel_h - 1) + 1; + int kernel_extent_w = dilation_w * (kernel_w - 1) + 1; + int output_h = (input_h + 2 * pad_h - kernel_extent_h) / stride_h + 1; + int output_w = (input_w + 2 * pad_w - kernel_extent_w) / stride_w + 1; + framework::DDim output_shape = framework::make_ddim( + std::vector({batch_size, output_c, output_h, output_w})); + + framework::DDim bias_shape = framework::make_ddim({output_c}); + + VariableNameMap inputs; + VariableNameMap outputs; + auto scope = std::make_shared(); + inputs["Input"] = std::vector({"input"}); + inputs["Filter"] = std::vector({"filter"}); + inputs["Scale"] = std::vector({"scale"}); + inputs["Y"] = std::vector({"y"}); + outputs["Output"] = std::vector({"output"}); + + auto input_var = scope.get()->Var("input"); + auto input = input_var->template GetMutable(); + SetupTensor(input, input_shape, -127, 127); + + auto filter_var = scope.get()->Var("filter"); + auto filter = filter_var->template GetMutable(); + SetupTensor(filter, filter_shape, -127, 127); + + auto scale_var = scope.get()->Var("scale"); + auto scale = scale_var->template GetMutable(); + scale->Resize(framework::make_ddim({1})); + float scale_v = 0.000828f; + scale->mutable_data()[0] = scale_v; + + auto bias_var = scope.get()->Var("y"); + auto bias = bias_var->template GetMutable(); + SetupTensor(bias, bias_shape, -127, 127); + + auto output_var = scope.get()->Var("output"); + framework::AttributeMap attrs; + attrs["strides"].Set>(std::vector({stride_h, stride_w})); + attrs["paddings"].Set>(std::vector({pad_h, pad_w})); + attrs["dilations"].Set>( + std::vector({dilation_h, dilation_w})); + attrs["groups"].Set(1); + attrs["axis"].Set(0); + + auto *op = new operators::FusionConvAddReluInt8Op( + "fusion_conv_add_relu_int8", inputs, outputs, attrs, scope); + op->InferShape(); + op->Init(); + op->Run(); + + framework::Tensor output_cmp; + output_cmp.mutable_data(output_shape); + conv2d(input, filter, bias, attrs, &output_cmp, scale_v); + + // compare results + int eq = 0; + int neq = 0; + auto output = output_var->template Get(); + const T *output_data = output->data(); + T *output_cmp_data = output_cmp.data(); + for (int i = 0; i < output->numel(); ++i) { + PADDLE_MOBILE_ENFORCE( + output_data[i] == output_cmp_data[i], + "The execution of test_fusion_conv_add_relu_int8_op is failed!"); + if (output_data[i] == output_cmp_data[i]) { + ++eq; + } else { + ++neq; + } + } + std::cout << "eq = " << eq << ", neq = " << neq << std::endl; + delete op; + return 0; +} + +} // namespace paddle_mobile + +int main(int argc, char *argv[]) { + if (argc < 5) { + LOG(paddle_mobile::kLOG_INFO) + << "Usage:\n" + << " ./test-conv-add-relu-int8-op in_channels in_height in_width " + "out_channels\n" + << " params:\n" + << " -in_channels: int, input image's channels\n" + << " -in_height: int, input image's height\n" + << " -in_width: int, input image's width\n" + << " -out_channels: int, conv output channels\n"; + return 1; + } + int in_channels = atoi(argv[1]); + int in_height = atoi(argv[2]); + int in_width = atoi(argv[3]); + int out_channels = atoi(argv[4]); + // kernel = 3, pad = 1, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "int8_t, kernel=3, pad=1, stride=1"; + paddle_mobile::TestConvOp(in_channels, in_height, in_width, + out_channels); + // kernel = 7, pad = 0, stride = 2 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=0, stride=2"; + paddle_mobile::TestConvOp(in_channels, in_height, in_width, + out_channels); + // kernel = 7, pad = 1, stride = 2 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=1, stride=2"; + paddle_mobile::TestConvOp(in_channels, in_height, in_width, + out_channels); + // kernel = 7, pad = 3, stride = 2 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=2"; + paddle_mobile::TestConvOp(in_channels, in_height, in_width, + out_channels); + // kernel = 7, pad = 0, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=0, stride=1"; + paddle_mobile::TestConvOp(in_channels, in_height, in_width, + out_channels); + // kernel = 7, pad = 1, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=1, stride=1"; + paddle_mobile::TestConvOp(in_channels, in_height, in_width, + out_channels); + // kernel = 7, pad = 3, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=1"; + paddle_mobile::TestConvOp(in_channels, in_height, in_width, + out_channels); + // kernel = 7, pad = 5, stride = 3 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=5, stride=3"; + paddle_mobile::TestConvOp(in_channels, in_height, in_width, + out_channels); + // kernel = 7, pad = 3, stride = 4 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=4"; + paddle_mobile::TestConvOp(in_channels, in_height, in_width, + out_channels); + // kernel = 3, pad = 0, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=3, pad=0, stride=1"; + paddle_mobile::TestConvOp(in_channels, in_height, in_width, + out_channels); + // kernel = 3, pad = 1, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=3, pad=1, stride=1"; + paddle_mobile::TestConvOp(in_channels, in_height, in_width, + out_channels); + + // kernel = 5, pad = 0, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=0, stride=1"; + paddle_mobile::TestConvOp(in_channels, in_height, in_width, + out_channels); + + // kernel = 5, pad = 2, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=2, stride=1"; + paddle_mobile::TestConvOp(in_channels, in_height, in_width, + out_channels); +} diff --git a/test/operators/test_mul_op.cpp b/test/operators/test_mul_op.cpp index 262ee960e1..83da418025 100644 --- a/test/operators/test_mul_op.cpp +++ b/test/operators/test_mul_op.cpp @@ -79,14 +79,14 @@ int TestMulOP() { PADDLE_MOBILE_ENFORCE( output_data[i] == c[i], "output[%d] = %d, output_cmp[%d] = %d", i, static_cast(output_data[i]), i, static_cast(c[i])); - if (static_cast(output_data[i] == c[i])) { + if (output_data[i] == c[i]) { ++eq; } else { ++neq; } } - DLOG << "mnk=" << m << " " << n << " " << k << " eq=" << eq - << " neq=" << neq; + std::cout << "mnk=" << m << " " << n << " " << k << " eq=" << eq + << " neq=" << neq << std::endl; delete op; return 0; } diff --git a/tools/op.cmake b/tools/op.cmake index 3a4a0597a4..45dbcdcf05 100644 --- a/tools/op.cmake +++ b/tools/op.cmake @@ -213,6 +213,7 @@ if(NOT FOUND_MATCH) set(FUSION_CONVADD_OP ON) set(FUSION_CONVADDPRELU_OP ON) set(FUSION_CONVADDRELU_OP ON) + set(FUSION_CONVADDRELU_INT8_OP ON) set(FUSION_FC_OP ON) set(LRN_OP ON) set(MUL_OP ON) @@ -306,6 +307,9 @@ endif() if (FUSION_CONVADDRELU_OP) add_definitions(-DFUSION_CONVADDRELU_OP) endif() +if (FUSION_CONVADDRELU_INT8_OP) + add_definitions(-DFUSION_CONVADDRELU_INT8_OP) +endif() if (FUSION_CONVADDPRELU_OP) add_definitions(-DFUSION_CONVADDPRELU_OP) endif() -- GitLab