diff --git a/src/common/types.cpp b/src/common/types.cpp index 36c93046c1c09d0ec5043ef9a7514dedf212e738..312e491a35681e2fc75584106160a4c79e22e372 100644 --- a/src/common/types.cpp +++ b/src/common/types.cpp @@ -71,6 +71,8 @@ const char *G_OP_TYPE_SUM = "sum"; const char *G_OP_TYPE_QUANTIZE = "quantize"; const char *G_OP_TYPE_DEQUANTIZE = "dequantize"; +const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN = "fusion_dequant_add_bn"; +const char *G_OP_TYPE_FUSION_DEQUANT_BN_RELU = "fusion_dequant_bn_relu"; const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU = "fusion_dequant_add_bn_relu"; const char *G_OP_TYPE_TANH = "tanh"; @@ -136,6 +138,8 @@ std::unordered_map< {G_OP_TYPE_ELEMENTWISE_MUL, {{"X", "Y"}, {"Out"}}}, {G_OP_TYPE_QUANTIZE, {{"X"}, {"Out", "OutScale"}}}, {G_OP_TYPE_DEQUANTIZE, {{"X", "Scale"}, {"Out"}}}, + {G_OP_TYPE_FUSION_DEQUANT_ADD_BN, {{"X", "Scale"}, {"Y"}}}, + {G_OP_TYPE_FUSION_DEQUANT_BN_RELU, {{"X", "Scale"}, {"Out"}}}, {G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU, {{"X", "Scale"}, {"Out"}}}, {G_OP_TYPE_TANH, {{"X"}, {"Out"}}}, {G_OP_TYPE_FUSION_DECONV_RELU, {{"Input"}, {"Out"}}}, diff --git a/src/common/types.h b/src/common/types.h index 5704618c9e475781b22df4ae3a0ac3a994eb8c90..16ed1aef57432249b14c415b3a23042ca295b600 100644 --- a/src/common/types.h +++ b/src/common/types.h @@ -138,6 +138,8 @@ extern const char *G_OP_TYPE_ELEMENTWISE_MUL; extern const char *G_OP_TYPE_QUANTIZE; extern const char *G_OP_TYPE_DEQUANTIZE; +extern const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN; +extern const char *G_OP_TYPE_FUSION_DEQUANT_BN_RELU; extern const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU; extern const char *G_OP_TYPE_TANH; diff --git a/src/framework/load_ops.h b/src/framework/load_ops.h index 135ef9083e42271fe63cdc29ee53e876f532c287..2534217d58674f912f0e5da741dfcae41827edf1 100644 --- a/src/framework/load_ops.h +++ b/src/framework/load_ops.h @@ -233,6 +233,14 @@ LOAD_OP1(quantize, CPU); #ifdef DEQUANT_OP LOAD_OP1(dequantize, CPU); #endif +#ifdef FUSION_DEQUANT_ADD_BN_OP +LOAD_OP1(fusion_dequant_add_bn, CPU); +LOAD_FUSION_MATCHER(fusion_dequant_add_bn); +#endif +#ifdef FUSION_DEQUANT_BN_RELU_OP +LOAD_OP1(fusion_dequant_bn_relu, CPU); +LOAD_FUSION_MATCHER(fusion_dequant_bn_relu); +#endif #ifdef FUSION_DEQUANT_ADD_BN_RELU_OP LOAD_OP1(fusion_dequant_add_bn_relu, CPU); LOAD_FUSION_MATCHER(fusion_dequant_add_bn_relu); diff --git a/src/io/ios_io/PaddleMobileCPU.mm b/src/io/ios_io/PaddleMobileCPU.mm index 2416c0d4e708813f8abf18c9dcb6e5d8b3c37a90..209022b64e90f700dc83c43d11f6e619c66673b6 100644 --- a/src/io/ios_io/PaddleMobileCPU.mm +++ b/src/io/ios_io/PaddleMobileCPU.mm @@ -95,7 +95,8 @@ static std::mutex shared_mutex; andModelParamsLen:(size_t)combinedParamsLen andCombinedParamsBuf:(const uint8_t *)combinedParamsBuf { pam_->SetThreadNum(2); - return loaded_ = pam_->LoadCombinedMemory(modelLen, modelBuf, combinedParamsLen, combinedParamsBuf); + return loaded_ = pam_->LoadCombinedMemory(modelLen, modelBuf, combinedParamsLen, + const_cast(combinedParamsBuf)); } - (BOOL)load:(NSString *)modelAndWeightPath{ diff --git a/src/operators/depthwise_conv_op.h b/src/operators/depthwise_conv_op.h index 102d65670d3e50acd15745e95b85d7b843994ed7..26253e0e0a7d3c52808a691d4257e7074e1da6e2 100644 --- a/src/operators/depthwise_conv_op.h +++ b/src/operators/depthwise_conv_op.h @@ -18,7 +18,7 @@ limitations under the License. 
*/ #include #include "framework/operator.h" -#include "operators/kernel/depthwise_conv_kernel.h" +#include "operators/kernel/conv_kernel.h" namespace paddle_mobile { namespace operators { @@ -26,19 +26,16 @@ namespace operators { template class DepthwiseConvOp : public framework::OperatorWithKernel< DeviceType, ConvParam, - operators::DepthwiseConvKernel> { + operators::ConvKernel> { public: DepthwiseConvOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, std::shared_ptr scope) - : framework::OperatorWithKernel< - DeviceType, ConvParam, - operators::DepthwiseConvKernel>( + : framework::OperatorWithKernel, + operators::ConvKernel>( type, inputs, outputs, attrs, scope) {} void InferShape() const override; - - private: }; } // namespace operators diff --git a/src/operators/kernel/arm/depthwise_conv_kernel.cpp b/src/operators/fusion_dequant_add_bn_op.cpp similarity index 58% rename from src/operators/kernel/arm/depthwise_conv_kernel.cpp rename to src/operators/fusion_dequant_add_bn_op.cpp index 000d59baa8c804201cbd2e2a731c2077196b698f..4df50af22b0dc9e214b0cabe303bf70edf50c307 100644 --- a/src/operators/kernel/arm/depthwise_conv_kernel.cpp +++ b/src/operators/fusion_dequant_add_bn_op.cpp @@ -12,27 +12,27 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef DEPTHWISECONV_OP +#ifdef FUSION_DEQUANT_ADD_BN_OP -#include "operators/kernel/depthwise_conv_kernel.h" -#include "operators/kernel/central-arm-func/depthwise_conv_arm_func.h" +#include "operators/fusion_dequant_add_bn_op.h" namespace paddle_mobile { namespace operators { -template <> -bool DepthwiseConvKernel::Init(ConvParam *param) { - return true; +template +void FusionDequantAddBNOp::InferShape() const { + const auto& input_dims = this->param_.input_->dims(); + this->param_.output_->Resize(input_dims); } -template <> -void DepthwiseConvKernel::Compute(const ConvParam ¶m) { - DepthwiseConvCompute(param); -} - -template class DepthwiseConvKernel; - } // namespace operators } // namespace paddle_mobile +namespace ops = paddle_mobile::operators; +REGISTER_FUSION_MATCHER(fusion_dequant_add_bn, ops::FusionDequantAddBNMatcher); + +#ifdef PADDLE_MOBILE_CPU +REGISTER_OPERATOR_CPU(fusion_dequant_add_bn, ops::FusionDequantAddBNOp); +#endif + #endif diff --git a/src/operators/fusion_dequant_add_bn_op.h b/src/operators/fusion_dequant_add_bn_op.h new file mode 100644 index 0000000000000000000000000000000000000000..8c4f353a81705c41c75a5aff92f2637b92755a2c --- /dev/null +++ b/src/operators/fusion_dequant_add_bn_op.h @@ -0,0 +1,74 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifdef FUSION_DEQUANT_ADD_BN_OP + +#pragma once + +#include +#include +#include "framework/operator.h" +#include "framework/program/program-optimize/fusion_op_register.h" +#include "operators/kernel/dequant_add_bn_kernel.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +class FusionDequantAddBNMatcher : public framework::FusionOpMatcher { + public: + FusionDequantAddBNMatcher() { + node_ = framework::Node(G_OP_TYPE_DEQUANTIZE); + node_ > std::make_shared(G_OP_TYPE_ELEMENTWISE_ADD) > + std::make_shared(G_OP_TYPE_BATCHNORM); + } + + void FolderNodes( + framework::Node *node, + std::vector> *removed_nodes) { + node->Folder(node_.Depth(), Type(), + {{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}}, + {G_OP_TYPE_BATCHNORM, + {{"Scale", "BNScale"}, + {"Mean", "BNMean"}, + {"Bias", "BNBias"}, + {"Variance", "BNVariance"}}}}, + removed_nodes); + } + + std::string Type() { return G_OP_TYPE_FUSION_DEQUANT_ADD_BN; } +}; + +template +class FusionDequantAddBNOp + : public framework::OperatorWithKernel< + DeviceType, FusionDequantAddBNParam, + operators::FusionDequantAddBNKernel> { + public: + FusionDequantAddBNOp(const std::string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, + const framework::AttributeMap &attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel< + DeviceType, FusionDequantAddBNParam, + operators::FusionDequantAddBNKernel>( + type, inputs, outputs, attrs, scope) {} + // inference output shape + void InferShape() const override; +}; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/fusion_dequant_add_bn_relu_op.h b/src/operators/fusion_dequant_add_bn_relu_op.h index dbd9ad0de2ece751ffd4da05cb09f0091a5755aa..b33d3c210ca56f27b769789fee08023ebb8c80de 100644 --- a/src/operators/fusion_dequant_add_bn_relu_op.h +++ b/src/operators/fusion_dequant_add_bn_relu_op.h @@ -20,7 +20,7 @@ limitations under the License. */ #include #include "framework/operator.h" #include "framework/program/program-optimize/fusion_op_register.h" -#include "operators/kernel/dequant_add_bn_relu_kernel.h" +#include "operators/kernel/dequant_bn_relu_kernel.h" #include "operators/op_param.h" namespace paddle_mobile { diff --git a/src/operators/kernel/depthwise_conv_kernel.h b/src/operators/fusion_dequant_bn_relu_op.cpp similarity index 56% rename from src/operators/kernel/depthwise_conv_kernel.h rename to src/operators/fusion_dequant_bn_relu_op.cpp index 3ee5bf86e97baa3970239e32b7fd5fc341e09f92..c843889a61a128c86915b14b0229ed172df2325b 100644 --- a/src/operators/kernel/depthwise_conv_kernel.h +++ b/src/operators/fusion_dequant_bn_relu_op.cpp @@ -12,29 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#ifdef DEPTHWISECONV_OP +#ifdef FUSION_DEQUANT_BN_RELU_OP -#pragma once - -#include "framework/operator.h" -#include "operators/math/im2col.h" -#include "operators/math/math_function.h" -#include "operators/math/vol2col.h" -#include "operators/op_param.h" +#include "operators/fusion_dequant_bn_relu_op.h" namespace paddle_mobile { namespace operators { -using framework::OpKernelBase; +template +void FusionDequantBNReluOp::InferShape() const { + const auto& input_dims = this->param_.input_->dims(); + this->param_.output_->Resize(input_dims); +} -template -class DepthwiseConvKernel - : public OpKernelBase> { - public: - void Compute(const ConvParam ¶m); - bool Init(ConvParam *param); -}; } // namespace operators } // namespace paddle_mobile +namespace ops = paddle_mobile::operators; +REGISTER_FUSION_MATCHER(fusion_dequant_bn_relu, + ops::FusionDequantBNReluMatcher); + +#ifdef PADDLE_MOBILE_CPU +REGISTER_OPERATOR_CPU(fusion_dequant_bn_relu, ops::FusionDequantBNReluOp); +#endif + #endif diff --git a/src/operators/fusion_dequant_bn_relu_op.h b/src/operators/fusion_dequant_bn_relu_op.h new file mode 100644 index 0000000000000000000000000000000000000000..b556df1e3707736be0eaf58eb8323cdbb64cbd74 --- /dev/null +++ b/src/operators/fusion_dequant_bn_relu_op.h @@ -0,0 +1,73 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifdef FUSION_DEQUANT_BN_RELU_OP + +#pragma once + +#include +#include +#include "framework/operator.h" +#include "framework/program/program-optimize/fusion_op_register.h" +#include "operators/kernel/dequant_bn_relu_kernel.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +class FusionDequantBNReluMatcher : public framework::FusionOpMatcher { + public: + FusionDequantBNReluMatcher() { + node_ = framework::Node(G_OP_TYPE_DEQUANTIZE); + node_ > std::make_shared(G_OP_TYPE_BATCHNORM) > + std::make_shared(G_OP_TYPE_RELU); + } + + void FolderNodes( + framework::Node *node, + std::vector> *removed_nodes) { + node->Folder(node_.Depth(), Type(), + {{G_OP_TYPE_BATCHNORM, + {{"Scale", "BNScale"}, + {"Mean", "BNMean"}, + {"Bias", "BNBias"}, + {"Variance", "BNVariance"}}}}, + removed_nodes); + } + + std::string Type() { return G_OP_TYPE_FUSION_DEQUANT_BN_RELU; } +}; + +template +class FusionDequantBNReluOp + : public framework::OperatorWithKernel< + DeviceType, FusionDequantBNReluParam, + operators::FusionDequantBNReluKernel> { + public: + FusionDequantBNReluOp(const std::string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, + const framework::AttributeMap &attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel< + DeviceType, FusionDequantBNReluParam, + operators::FusionDequantBNReluKernel>( + type, inputs, outputs, attrs, scope) {} + // inference output shape + void InferShape() const override; +}; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/arm/conv_kernel.cpp b/src/operators/kernel/arm/conv_kernel.cpp index 840be6c67d2e350c914a7d8aa8e9a32acdd00fb1..5384faf2b8ae0e0fe6aed1b6c0cd7d4d16978ac9 100644 --- a/src/operators/kernel/arm/conv_kernel.cpp +++ b/src/operators/kernel/arm/conv_kernel.cpp @@ -22,41 +22,43 @@ namespace operators { template <> bool ConvKernel::Init(ConvParam *param) { + bool conv3x3 = param->Filter()->dims()[2] == param->Filter()->dims()[3] && + param->Filter()->dims()[2] == 3; + bool depth3x3 = conv3x3 && param->Groups() == param->Input()->dims()[1] && + param->Input()->dims()[1] == param->Output()->dims()[1]; if (param->Filter()->type() == typeid(int8_t)) { - if (param->Groups() == param->Input()->dims()[1] && - param->Input()->dims()[1] == param->Output()->dims()[1] && - param->Filter()->dims()[2] == param->Filter()->dims()[3] && - param->Filter()->dims()[2] == 3 && param->Strides()[0] < 3 && + if (depth3x3 && param->Strides()[0] < 3 && param->Strides()[0] == param->Strides()[1]) { param->ExecMode() = ConvParam::EXEC_DEPTHWISE3x3_INT8; } else { param->ExecMode() = ConvParam::EXEC_GEMM_INT8; } } else { - if (param->Groups() == param->Input()->dims()[1] && - param->Input()->dims()[1] == param->Output()->dims()[1] && - param->Filter()->dims()[2] == param->Filter()->dims()[3] && - param->Filter()->dims()[2] == 3 && param->Strides()[0] == 1) { + if (depth3x3 && param->Strides()[0] == param->Strides()[1] && + param->Strides()[0] == 1 && param->Paddings()[0] == 1 && + param->Paddings()[0] == param->Paddings()[1]) { param->ExecMode() = ConvParam::EXEC_DEPTHWISE3x3S1P1_FLOAT; - } else if (param->Groups() == param->Input()->dims()[1] && - param->Input()->dims()[1] == param->Output()->dims()[1] && - param->Filter()->dims()[2] == param->Filter()->dims()[3] && - param->Filter()->dims()[2] == 3) { - param->ExecMode() = ConvParam::EXEC_DEPTHWISE3x3_FLOAT; + } else if (depth3x3 && param->Strides()[0] == param->Strides()[1] && + param->Strides()[0] == 2 && 
param->Paddings()[0] == 0 && + param->Paddings()[0] == param->Paddings()[1]) { + param->ExecMode() = ConvParam::EXEC_DEPTHWISE3x3S2P0_FLOAT; + } else if (depth3x3 && param->Strides()[0] == param->Strides()[1] && + param->Strides()[0] == 2 && param->Paddings()[0] == 1 && + param->Paddings()[0] == param->Paddings()[1]) { + param->ExecMode() = ConvParam::EXEC_DEPTHWISE3x3S2P1_FLOAT; #ifndef __aarch64__ - } else if (param->Filter()->dims()[2] == param->Filter()->dims()[3] && - param->Strides()[0] == param->Strides()[1] && + } else if (conv3x3 && param->Strides()[0] == param->Strides()[1] && param->Dilations()[0] == param->Dilations()[1] && - param->Filter()->dims()[2] == 3 && param->Strides()[0] == 1 && - param->Dilations()[0] == 1 && param->Output()->dims()[1] >= 16 && + param->Strides()[0] == 1 && param->Dilations()[0] == 1 && + param->Output()->dims()[1] >= 16 && param->Input()->dims()[1] >= 16 && param->Input()->dims()[2] <= 140 /* refered from ncnn */) { param->ExecMode() = ConvParam::EXEC_WINOGRAD3X3_FLOAT; // transform weight - framework::Tensor *transformed_weight = new framework::Tensor; + framework::Tensor transformed_weight; operators::math::winograd_transform_weight<8, 3>(*param->Filter(), - transformed_weight); - param->Filter() = transformed_weight; + &transformed_weight); + framework::TensorCopy(transformed_weight, param->Filter()); #endif } else { param->ExecMode() = ConvParam::EXEC_GEMM_FLOAT; @@ -78,9 +80,13 @@ void ConvKernel::Compute(const ConvParam ¶m) { math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), nullptr, false); break; - case ConvParam::EXEC_DEPTHWISE3x3_FLOAT: - math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(), - param.Filter(), nullptr, param.Output(), false); + case ConvParam::EXEC_DEPTHWISE3x3S2P1_FLOAT: + math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(), + param.Output(), nullptr, false); + break; + case ConvParam::EXEC_DEPTHWISE3x3S2P0_FLOAT: + math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(), + nullptr, false); break; case ConvParam::EXEC_WINOGRAD3X3_FLOAT: WinogradConv3x3<8, 3>(param); diff --git a/src/operators/kernel/arm/dequant_add_bn_relu_kernel.cpp b/src/operators/kernel/arm/dequant_add_bn_kernel.cpp similarity index 86% rename from src/operators/kernel/arm/dequant_add_bn_relu_kernel.cpp rename to src/operators/kernel/arm/dequant_add_bn_kernel.cpp index bfe1935c216f94d660997b1bfa42f18e63295992..65fb0190f76a34a584d065bd43841567e9658bb8 100644 --- a/src/operators/kernel/arm/dequant_add_bn_relu_kernel.cpp +++ b/src/operators/kernel/arm/dequant_add_bn_kernel.cpp @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP +#ifdef FUSION_DEQUANT_ADD_BN_OP -#include "operators/kernel/dequant_add_bn_relu_kernel.h" +#include "operators/kernel/dequant_add_bn_kernel.h" #include #if defined(__ARM_NEON__) || defined(__ARM_NEON) #include @@ -24,8 +24,8 @@ namespace paddle_mobile { namespace operators { template <> -bool FusionDequantAddBNReluKernel::Init( - FusionDequantAddBNReluParam *param) { +bool FusionDequantAddBNKernel::Init( + FusionDequantAddBNParam *param) { // elementwise add params const Tensor *bias = param->bias_; // batch norm params @@ -49,8 +49,8 @@ bool FusionDequantAddBNReluKernel::Init( } template <> -void FusionDequantAddBNReluKernel::Compute( - const FusionDequantAddBNReluParam ¶m) { +void FusionDequantAddBNKernel::Compute( + const FusionDequantAddBNParam ¶m) { const int32_t *input = param.input_->data(); const float *bn_scale = param.bn_scale_->data(); const float *bn_bias = param.bn_bias_->data(); @@ -78,7 +78,6 @@ void FusionDequantAddBNReluKernel::Compute( remain = spatial_size & 0xF; float32x4_t __scale = vdupq_n_f32(scale); float32x4_t __bias = vdupq_n_f32(bias); - float32x4_t __zero = vdupq_n_f32(0.f); for (int k = 0; k < loop; ++k, x += 16, y += 16) { int32x4_t r0 = vld1q_s32(x); @@ -93,10 +92,6 @@ void FusionDequantAddBNReluKernel::Compute( f1 = vmlaq_f32(__bias, __scale, f1); f2 = vmlaq_f32(__bias, __scale, f2); f3 = vmlaq_f32(__bias, __scale, f3); - f0 = vmaxq_f32(__zero, f0); - f1 = vmaxq_f32(__zero, f1); - f2 = vmaxq_f32(__zero, f2); - f3 = vmaxq_f32(__zero, f3); vst1q_f32(y, f0); vst1q_f32(y + 4, f1); vst1q_f32(y + 8, f2); @@ -104,7 +99,7 @@ void FusionDequantAddBNReluKernel::Compute( } #endif // __ARM_NEON__ for (int k = 0; k < remain; ++k) { - y[k] = std::max(scale * x[k] + bias, 0.f); + y[k] = scale * x[k] + bias; } } } @@ -113,4 +108,4 @@ void FusionDequantAddBNReluKernel::Compute( } // namespace operators } // namespace paddle_mobile -#endif // FUSION_DEQUANT_ADD_BN_RELU_OP +#endif // FUSION_DEQUANT_ADD_BN_OP diff --git a/src/operators/kernel/arm/dequant_bn_relu_kernel.cpp b/src/operators/kernel/arm/dequant_bn_relu_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4d656712c193aa81a8be11c53856c868e2b82483 --- /dev/null +++ b/src/operators/kernel/arm/dequant_bn_relu_kernel.cpp @@ -0,0 +1,150 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "operators/kernel/dequant_bn_relu_kernel.h" +#include +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#include +#endif + +namespace paddle_mobile { +namespace operators { + +#if defined(FUSION_DEQUANT_BN_RELU_OP) || defined(FUSION_DEQUANT_ADD_BN_RELU_OP) +void DequantBNReluCompute(const FusionDequantBNParam *param) { + const int32_t *input = param->input_->data(); + const float *bn_scale = param->bn_scale_->data(); + const float *bn_bias = param->bn_bias_->data(); + // dequantize params + const float activation_scale = param->activation_scale_->data()[0]; + const float weight_scale = param->weight_scale_; + const float dequant_scale = activation_scale / weight_scale; + + float *output = param->output_->mutable_data(); + int batch_size = param->input_->dims()[0]; + int channels = param->input_->dims()[1]; + size_t spatial_size = param->input_->dims()[2] * param->input_->dims()[3]; + + #pragma omp parallel for collapse(2) + for (int batch = 0; batch < batch_size; ++batch) { + for (int c = 0; c < channels; ++c) { + float scale = bn_scale[c] * dequant_scale; + float bias = bn_bias[c]; + size_t offset = (batch * channels + c) * spatial_size; + const int32_t *x = input + offset; + float *y = output + offset; + size_t remain = spatial_size; +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + int loop = spatial_size >> 4; + remain = spatial_size & 0xF; + float32x4_t __scale = vdupq_n_f32(scale); + float32x4_t __bias = vdupq_n_f32(bias); + float32x4_t __zero = vdupq_n_f32(0.f); + + for (int k = 0; k < loop; ++k, x += 16, y += 16) { + int32x4_t r0 = vld1q_s32(x); + int32x4_t r1 = vld1q_s32(x + 4); + int32x4_t r2 = vld1q_s32(x + 8); + int32x4_t r3 = vld1q_s32(x + 12); + float32x4_t f0 = vcvtq_f32_s32(r0); + float32x4_t f1 = vcvtq_f32_s32(r1); + float32x4_t f2 = vcvtq_f32_s32(r2); + float32x4_t f3 = vcvtq_f32_s32(r3); + f0 = vmlaq_f32(__bias, __scale, f0); + f1 = vmlaq_f32(__bias, __scale, f1); + f2 = vmlaq_f32(__bias, __scale, f2); + f3 = vmlaq_f32(__bias, __scale, f3); + f0 = vmaxq_f32(__zero, f0); + f1 = vmaxq_f32(__zero, f1); + f2 = vmaxq_f32(__zero, f2); + f3 = vmaxq_f32(__zero, f3); + vst1q_f32(y, f0); + vst1q_f32(y + 4, f1); + vst1q_f32(y + 8, f2); + vst1q_f32(y + 12, f3); + } +#endif // __ARM_NEON__ + for (int k = 0; k < remain; ++k) { + y[k] = std::max(scale * x[k] + bias, 0.f); + } + } + } +} +#endif + +#ifdef FUSION_DEQUANT_BN_RELU_OP +template <> +bool FusionDequantBNReluKernel::Init( + FusionDequantBNReluParam *param) { + // batch norm params + const Tensor *bn_mean = param->bn_mean_; + const Tensor *bn_variance = param->bn_variance_; + Tensor *bn_scale = param->bn_scale_; + Tensor *bn_bias = param->bn_bias_; + const float epsilon = param->epsilon_; + + const float *mean_ptr = bn_mean->data(); + const float *var_ptr = bn_variance->data(); + float *bn_scale_ptr = bn_scale->mutable_data(); + float *bn_bias_ptr = bn_bias->mutable_data(); + for (int c = 0; c < bn_scale->numel(); ++c) { + float inv_scale = bn_scale_ptr[c] / (std::sqrt(var_ptr[c] + epsilon)); + bn_scale_ptr[c] = inv_scale; + bn_bias_ptr[c] = bn_bias_ptr[c] - inv_scale * mean_ptr[c]; + } + return true; +} + +template <> +void FusionDequantBNReluKernel::Compute( + const FusionDequantBNReluParam ¶m) { + DequantBNReluCompute(¶m); +} +#endif // FUSION_DEQUANT_BN_RELU_OP + +#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP +template <> +bool FusionDequantAddBNReluKernel::Init( + FusionDequantAddBNReluParam *param) { + // elementwise add params + const Tensor *bias = param->bias_; + // batch norm params + const Tensor *bn_mean = 
param->bn_mean_; + const Tensor *bn_variance = param->bn_variance_; + Tensor *bn_scale = param->bn_scale_; + Tensor *bn_bias = param->bn_bias_; + const float epsilon = param->epsilon_; + + const float *bias_ptr = bias->data(); + const float *mean_ptr = bn_mean->data(); + const float *var_ptr = bn_variance->data(); + float *bn_scale_ptr = bn_scale->mutable_data(); + float *bn_bias_ptr = bn_bias->mutable_data(); + for (int c = 0; c < bn_scale->numel(); ++c) { + float inv_scale = bn_scale_ptr[c] / (std::sqrt(var_ptr[c] + epsilon)); + bn_scale_ptr[c] = inv_scale; + bn_bias_ptr[c] = inv_scale * (bias_ptr[c] - mean_ptr[c]) + bn_bias_ptr[c]; + } + return true; +} + +template <> +void FusionDequantAddBNReluKernel::Compute( + const FusionDequantAddBNReluParam ¶m) { + DequantBNReluCompute(¶m); +} +#endif // FUSION_DEQUANT_ADD_BN_RELU_OP + +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/kernel/arm/quantize_kernel.cpp b/src/operators/kernel/arm/quantize_kernel.cpp index 1e7623436a1a73644aca61e4634a7cd405bd64ad..ca3fa71f98f778752ac9dd7728385f5525696a02 100644 --- a/src/operators/kernel/arm/quantize_kernel.cpp +++ b/src/operators/kernel/arm/quantize_kernel.cpp @@ -20,6 +20,9 @@ limitations under the License. */ #if defined(__ARM_NEON__) || defined(__ARM_NEON) #include +namespace paddle_mobile { +namespace operators { + #ifndef __aarch64__ inline float32_t vmaxvq_f32(float32x4_t r) { float32x2_t v = vmax_f32(vget_high_f32(r), vget_low_f32(r)); @@ -27,9 +30,13 @@ inline float32_t vmaxvq_f32(float32x4_t r) { } #endif -inline int32x4_t vrnd_towards_zero(float32x4_t r) { return vcvtq_s32_f32(r); } +template +inline int32x4_t vround_f32(float32x4_t r) { + return vcvtq_s32_f32(r); +} -inline int32x4_t vrnd_away_zero(float32x4_t r) { +template <> +inline int32x4_t vround_f32(float32x4_t r) { float32x4_t plus = vdupq_n_f32(0.5); float32x4_t minus = vdupq_n_f32(-0.5); float32x4_t zero = vdupq_n_f32(0); @@ -40,31 +47,13 @@ inline int32x4_t vrnd_away_zero(float32x4_t r) { return ret; } -inline int32x4_t vrnd_to_even(float32x4_t r) { -#if 0 - int32x4_t ret; - float value[4]; - vst1q_f32(value, r); - for (int i = 0; i < 4; ++i) { - float v = round(value[i]); - int32_t q = (int32_t)v; - if (abs(abs(v - value[i]) - 0.5) > 0) { - ret[i] = q; - } else { - if (abs(q) % 2 == 0) { - ret[i] = q; - } else { - ret[i] = q + ((q > 0) ? 
-1 : 1); - } - } - } - return ret; -#else +template <> +inline int32x4_t vround_f32(float32x4_t r) { float32x4_t point5 = vdupq_n_f32(0.5); int32x4_t one = vdupq_n_s32(1); int32x4_t zero = vdupq_n_s32(0); - int32x4_t rnd = vrnd_away_zero(r); + int32x4_t rnd = vround_f32(r); float32x4_t frnd = vcvtq_f32_s32(rnd); frnd = vsubq_f32(frnd, r); frnd = vabsq_f32(frnd); @@ -82,115 +71,39 @@ inline int32x4_t vrnd_to_even(float32x4_t r) { smask = vsubq_s32(smask, one); rnd = vaddq_s32(rnd, smask); return rnd; -#endif } - -namespace paddle_mobile { -namespace operators { - -static float find_abs_max(const Tensor *input) { - float max_abs = 0.f; - const float *x = input->data(); - size_t size = input->numel(); -#if defined(__ARM_NEON__) || defined(__ARM_NEON) - size_t loop = size >> 4; - size_t remain = size & 0xF; - for (size_t i = 0; i < loop; ++i) { - float32x4_t max; - float32x4_t r0 = vld1q_f32(x); - float32x4_t r1 = vld1q_f32(x + 4); - float32x4_t r2 = vld1q_f32(x + 8); - float32x4_t r3 = vld1q_f32(x + 12); - r0 = vabsq_f32(r0); - r1 = vabsq_f32(r1); - r2 = vabsq_f32(r2); - r3 = vabsq_f32(r3); - max[0] = vmaxvq_f32(r0); - max[1] = vmaxvq_f32(r1); - max[2] = vmaxvq_f32(r2); - max[3] = vmaxvq_f32(r3); - max[0] = vmaxvq_f32(max); - if (max[0] > max_abs) { - max_abs = max[0]; - } - x += 16; - } - size = remain; #endif - for (size_t i = 0; i < size; ++i) { - float value = std::abs(x[i]); - if (value > max_abs) { - max_abs = value; - } - } - return max_abs; + +template +inline int8_t Round(const float &x) { + return static_cast(x); } -#ifdef __aarch64__ -static void quantize_round_to_even(const Tensor *input, const float scale, - Tensor *output) { - const float *x = input->data(); - int8_t *y = output->mutable_data(); - size_t size = input->numel(); -#if defined(__ARM_NEON__) || defined(__ARM_NEON) - size_t loop = size >> 4; - size_t remain = size & 0xF; +template <> +inline int8_t Round(const float &x) { + return std::round(x); +} - #pragma omp parallel for - for (size_t i = 0; i < loop; ++i) { - const float *local_x = x + (i << 4); - int8_t *local_y = y + (i << 4); - float32x4_t r0 = vld1q_f32(local_x); - float32x4_t r1 = vld1q_f32(local_x + 4); - float32x4_t r2 = vld1q_f32(local_x + 8); - float32x4_t r3 = vld1q_f32(local_x + 12); - r0 = vmulq_n_f32(r0, scale); - r1 = vmulq_n_f32(r1, scale); - r2 = vmulq_n_f32(r2, scale); - r3 = vmulq_n_f32(r3, scale); - int32x4_t q0 = vrnd_to_even(r0); - int32x4_t q1 = vrnd_to_even(r1); - int32x4_t q2 = vrnd_to_even(r2); - int32x4_t q3 = vrnd_to_even(r3); - int16x4_t d0 = vmovn_s32(q0); - int16x4_t d1 = vmovn_s32(q1); - int16x4_t d2 = vmovn_s32(q2); - int16x4_t d3 = vmovn_s32(q3); - int16x8_t q5 = vcombine_s16(d0, d1); - int16x8_t q6 = vcombine_s16(d2, d3); - int8x8_t d5 = vmovn_s16(q5); - int8x8_t d6 = vmovn_s16(q6); - vst1_s8(local_y, d5); - vst1_s8(local_y + 8, d6); - } - size = remain; - x += (loop << 4); - y += (loop << 4); -#endif - for (size_t i = 0; i < size; ++i) { - float value = x[i] * scale; - float v = round(value); - int32_t q = (int32_t)v; - if (abs(abs(q - value) - 0.5) > 0) { - y[i] = q; - } else { - if (abs(q) % 2 == 0) { - y[i] = q; - } else { - y[i] = q + ((q > 0) ? -1 : 1); - } +template <> +inline int8_t Round(const float &x) { + float v = std::round(x); + int32_t q = static_cast(v); + if (std::abs(std::abs(q - v) - 0.5) <= 0) { + if (std::abs(q) % 2 != 0) { + q = q + ((q > 0) ? 
-1 : 1); } } + return static_cast(q); } -static void quantize_round_to_zero(const Tensor *input, const float scale, - Tensor *output) { +template +static void Quantize(const Tensor *input, const float scale, Tensor *output) { const float *x = input->data(); int8_t *y = output->mutable_data(); - size_t size = input->numel(); + size_t remain = input->numel(); #if defined(__ARM_NEON__) || defined(__ARM_NEON) - size_t loop = size >> 4; - size_t remain = size & 0xF; + size_t loop = remain >> 4; + remain = remain & 0xF; #pragma omp parallel for for (size_t i = 0; i < loop; ++i) { @@ -204,10 +117,10 @@ static void quantize_round_to_zero(const Tensor *input, const float scale, r1 = vmulq_n_f32(r1, scale); r2 = vmulq_n_f32(r2, scale); r3 = vmulq_n_f32(r3, scale); - int32x4_t q0 = vrnd_towards_zero(r0); - int32x4_t q1 = vrnd_towards_zero(r1); - int32x4_t q2 = vrnd_towards_zero(r2); - int32x4_t q3 = vrnd_towards_zero(r3); + int32x4_t q0 = vround_f32(r0); + int32x4_t q1 = vround_f32(r1); + int32x4_t q2 = vround_f32(r2); + int32x4_t q3 = vround_f32(r3); int16x4_t d0 = vmovn_s32(q0); int16x4_t d1 = vmovn_s32(q1); int16x4_t d2 = vmovn_s32(q2); @@ -219,561 +132,44 @@ static void quantize_round_to_zero(const Tensor *input, const float scale, vst1_s8(local_y, d5); vst1_s8(local_y + 8, d6); } - size = remain; x += (loop << 4); y += (loop << 4); #endif - for (size_t i = 0; i < size; ++i) { - y[i] = static_cast(x[i] * scale); + for (size_t i = 0; i < remain; ++i) { + y[i] = Round(x[i] * scale); } } -static void quantize_round_to_nearest(const Tensor *input, const float scale, - Tensor *output) { +float find_abs_max(const Tensor *input) { + float max_abs = 0.f; const float *x = input->data(); - int8_t *y = output->mutable_data(); - size_t size = input->numel(); + size_t remain = input->numel(); #if defined(__ARM_NEON__) || defined(__ARM_NEON) - size_t loop = size >> 4; - size_t remain = size & 0xF; + size_t loop = remain >> 4; + remain = remain & 0xF; + float32x4_t __max = {0.f, 0.f, 0.f, 0.f}; - #pragma omp parallel for - for (size_t i = 0; i < loop; ++i) { - const float *local_x = x + (i << 4); - int8_t *local_y = y + (i << 4); - float32x4_t r0 = vld1q_f32(local_x); - float32x4_t r1 = vld1q_f32(local_x + 4); - float32x4_t r2 = vld1q_f32(local_x + 8); - float32x4_t r3 = vld1q_f32(local_x + 12); - r0 = vmulq_n_f32(r0, scale); - r1 = vmulq_n_f32(r1, scale); - r2 = vmulq_n_f32(r2, scale); - r3 = vmulq_n_f32(r3, scale); - int32x4_t q0 = vrnd_away_zero(r0); - int32x4_t q1 = vrnd_away_zero(r1); - int32x4_t q2 = vrnd_away_zero(r2); - int32x4_t q3 = vrnd_away_zero(r3); - int16x4_t d0 = vmovn_s32(q0); - int16x4_t d1 = vmovn_s32(q1); - int16x4_t d2 = vmovn_s32(q2); - int16x4_t d3 = vmovn_s32(q3); - int16x8_t q5 = vcombine_s16(d0, d1); - int16x8_t q6 = vcombine_s16(d2, d3); - int8x8_t d5 = vmovn_s16(q5); - int8x8_t d6 = vmovn_s16(q6); - vst1_s8(local_y, d5); - vst1_s8(local_y + 8, d6); + for (size_t i = 0; i < loop; ++i, x += 16) { + float32x4_t r0 = vld1q_f32(x); + float32x4_t r1 = vld1q_f32(x + 4); + float32x4_t r2 = vld1q_f32(x + 8); + float32x4_t r3 = vld1q_f32(x + 12); + r0 = vabsq_f32(r0); + r1 = vabsq_f32(r1); + r2 = vabsq_f32(r2); + r3 = vabsq_f32(r3); + r0 = vmaxq_f32(r0, r1); + r1 = vmaxq_f32(r2, r3); + r0 = vmaxq_f32(r0, r1); + __max = vmaxq_f32(r0, __max); } - size = remain; - x += (loop << 4); - y += (loop << 4); + max_abs = vmaxvq_f32(__max); #endif - for (size_t i = 0; i < size; ++i) { - y[i] = round(x[i] * scale); - } -} -#else // __aarch64__ - -static void quantize_round_to_even(const Tensor *input, 
const float scale, - const std::vector &paddings, - const int8_t padding_val, Tensor *output) {} - -static void quantize_round_to_nearest(const Tensor *input, const float scale, - const std::vector &paddings, - const int8_t padding_val, - Tensor *output) {} - -static void quantize_round_to_zero(const Tensor *input, const float scale, - const std::vector &paddings, - const int8_t padding_val, Tensor *output) { - int channels = input->dims()[1]; - int input_h = input->dims()[2]; - int input_w = input->dims()[3]; - int output_h = output->dims()[2]; - int output_w = output->dims()[3]; - int input_spatial_size = input_h * input_w; - int output_spatial_size = output_h * output_w; - const float *x = input->data(); - int8_t *y = output->mutable_data(); - // valid area start - int start = paddings[0] * output_w + paddings[1]; - - for (int batch = 0; batch < input->dims()[0]; ++batch) { - #pragma omp parallel for - for (int c = 0; c < channels - 3; c += 4) { - const float *input0 = x + (batch * channels + c) * input_spatial_size; - const float *input1 = input0 + input_spatial_size; - const float *input2 = input1 + input_spatial_size; - const float *input3 = input2 + input_spatial_size; - size_t offset = (batch * channels + c) * output_spatial_size; - for (int h = 0; h < 2; ++h) { - int8_t *y0 = - y + offset + h * ((input_h + paddings[0]) * output_w - paddings[1]); - int8_t *y1 = y0 + output_spatial_size; - int8_t *y2 = y1 + output_spatial_size; - int8_t *y3 = y2 + output_spatial_size; - int loop = start >> 4; - int remain = start & 0xF; - asm volatile( - "vdup.s8 q0, %[val] \n" - "cmp %[loop], #0 \n" - "ble start_remain_%= \n" - - "store_16w_%=: \n" - "vst1.32 {q0}, [%[y0]]! \n" - "vst1.32 {q0}, [%[y1]]! \n" - "vst1.32 {q0}, [%[y2]]! \n" - "vst1.32 {q0}, [%[y3]]! \n" - "subs %[loop], #1 \n" - "bne store_16w_%= \n" - - "start_remain_%=: \n" - "cmp %[remain], #8 \n" - "blt store_4w_%= \n" - "vst1.32 {d0}, [%[y0]]! \n" - "vst1.32 {d0}, [%[y1]]! \n" - "vst1.32 {d0}, [%[y2]]! \n" - "vst1.32 {d0}, [%[y3]]! \n" - "sub %[remain], #8 \n" - - "store_4w_%=: \n" - "cmp %[remain], #4 \n" - "blt store_2w_%= \n" - "vst1.32 {d0[0]}, [%[y0]]! \n" - "vst1.32 {d0[0]}, [%[y1]]! \n" - "vst1.32 {d0[0]}, [%[y2]]! \n" - "vst1.32 {d0[0]}, [%[y3]]! \n" - "sub %[remain], #4 \n" - - "store_2w_%=: \n" - "cmp %[remain], #4 \n" - "blt store_1w_%= \n" - "vst1.16 {d0[0]}, [%[y0]]! \n" - "vst1.16 {d0[0]}, [%[y1]]! \n" - "vst1.16 {d0[0]}, [%[y2]]! \n" - "vst1.16 {d0[0]}, [%[y3]]! \n" - "sub %[remain], #2 \n" - - "store_1w_%=: \n" - "cmp %[remain], #1 \n" - "blt end_%= \n" - "vst1.8 {d0[0]}, [%[y0]]! \n" - "vst1.8 {d0[0]}, [%[y1]]! \n" - "vst1.8 {d0[0]}, [%[y2]]! \n" - "vst1.8 {d0[0]}, [%[y3]]! 
\n" - "end_%=: \n" - : [y0] "+r"(y0), [y1] "+r"(y1), [y2] "+r"(y2), [y3] "+r"(y3), - [loop] "+r"(loop), [remain] "+r"(remain) - : [val] "r"(padding_val) - : "cc", "memory", "q0"); - } - // quantize valid area - int8_t *y0 = y + offset + start; - int8_t *y1 = y0 + output_spatial_size; - int8_t *y2 = y1 + output_spatial_size; - int8_t *y3 = y2 + output_spatial_size; - for (int h = 0; h < input_h; ++h) { - const float *x0 = input0 + h * input_w; - const float *x1 = input1 + h * input_w; - const float *x2 = input2 + h * input_w; - const float *x3 = input3 + h * input_w; - int loop = input_w >> 4; - int remain = input_w & 0xF; - int pad_loop = paddings[1] >> 1; // (paddings[1] << 1) >> 2 - int pad_remain = (paddings[1] << 1) & 0x3; - int remain_steps = remain; - asm volatile( - "vdup.f32 q0, %[scale] \n" - "cmp %[loop], #0 \n" - "ble quantize_remain_%= \n" - - "loop_quantize_%=: \n" - "vld1.32 {q1, q2}, [%[x0]]! \n" - "vld1.32 {q3, q4}, [%[x1]]! \n" - "vld1.32 {q5, q6}, [%[x2]]! \n" - "vld1.32 {q7, q8}, [%[x3]]! \n" - "vmul.f32 q1, q1, q0 \n" - "vmul.f32 q2, q2, q0 \n" - "vmul.f32 q3, q3, q0 \n" - "vmul.f32 q4, q4, q0 \n" - "vmul.f32 q5, q5, q0 \n" - "vmul.f32 q6, q6, q0 \n" - "vmul.f32 q7, q7, q0 \n" - "vmul.f32 q8, q8, q0 \n" - "vcvt.s32.f32 q1, q1 \n" - "vcvt.s32.f32 q2, q2 \n" - "vcvt.s32.f32 q3, q3 \n" - "vcvt.s32.f32 q4, q4 \n" - "vcvt.s32.f32 q5, q5 \n" - "vcvt.s32.f32 q6, q6 \n" - "vcvt.s32.f32 q7, q7 \n" - "vcvt.s32.f32 q8, q8 \n" - "vmovn.s32 d2, q1 \n" - "vmovn.s32 d3, q2 \n" - "vmovn.s32 d4, q3 \n" - "vmovn.s32 d5, q4 \n" - "vmovn.s32 d6, q5 \n" - "vmovn.s32 d7, q6 \n" - "vmovn.s32 d8, q7 \n" - "vmovn.s32 d9, q8 \n" - "vmovn.s16 d18, q1 \n" - "vmovn.s16 d20, q2 \n" - "vmovn.s16 d22, q3 \n" - "vmovn.s16 d24, q4 \n" - "vld1.32 {q1, q2}, [%[x0]]! \n" - "vld1.32 {q3, q4}, [%[x1]]! \n" - "vld1.32 {q5, q6}, [%[x2]]! \n" - "vld1.32 {q7, q8}, [%[x3]]! \n" - "vmul.f32 q1, q1, q0 \n" - "vmul.f32 q2, q2, q0 \n" - "vmul.f32 q3, q3, q0 \n" - "vmul.f32 q4, q4, q0 \n" - "vmul.f32 q5, q5, q0 \n" - "vmul.f32 q6, q6, q0 \n" - "vmul.f32 q7, q7, q0 \n" - "vmul.f32 q8, q8, q0 \n" - "vcvt.s32.f32 q1, q1 \n" - "vcvt.s32.f32 q2, q2 \n" - "vcvt.s32.f32 q3, q3 \n" - "vcvt.s32.f32 q4, q4 \n" - "vcvt.s32.f32 q5, q5 \n" - "vcvt.s32.f32 q6, q6 \n" - "vcvt.s32.f32 q7, q7 \n" - "vcvt.s32.f32 q8, q8 \n" - "vmovn.s32 d2, q1 \n" - "vmovn.s32 d3, q2 \n" - "vmovn.s32 d4, q3 \n" - "vmovn.s32 d5, q4 \n" - "vmovn.s32 d6, q5 \n" - "vmovn.s32 d7, q6 \n" - "vmovn.s32 d8, q7 \n" - "vmovn.s32 d9, q8 \n" - "vmovn.s16 d19, q1 \n" - "vmovn.s16 d21, q2 \n" - "vmovn.s16 d23, q3 \n" - "vmovn.s16 d25, q4 \n" - "vst1.32 {q9}, [%[y0]]! \n" - "vst1.32 {q10}, [%[y1]]! \n" - "vst1.32 {q11}, [%[y2]]! \n" - "vst1.32 {q12}, [%[y3]]! \n" - - "subs %[loop], #1 \n" - "bne loop_quantize_%= \n" - - "quantize_remain_%=: \n" - "cmp %[remain], #0 \n" - "ble end_%= \n" - - "vld1.32 {q1, q2}, [%[x0]]! \n" - "vld1.32 {q3, q4}, [%[x1]]! \n" - "vld1.32 {q5, q6}, [%[x2]]! \n" - "vld1.32 {q7, q8}, [%[x3]]! 
\n" - "vmul.f32 q1, q1, q0 \n" - "vmul.f32 q2, q2, q0 \n" - "vmul.f32 q3, q3, q0 \n" - "vmul.f32 q4, q4, q0 \n" - "vmul.f32 q5, q5, q0 \n" - "vmul.f32 q6, q6, q0 \n" - "vmul.f32 q7, q7, q0 \n" - "vmul.f32 q8, q8, q0 \n" - "vcvt.s32.f32 q1, q1 \n" - "vcvt.s32.f32 q2, q2 \n" - "vcvt.s32.f32 q3, q3 \n" - "vcvt.s32.f32 q4, q4 \n" - "vcvt.s32.f32 q5, q5 \n" - "vcvt.s32.f32 q6, q6 \n" - "vcvt.s32.f32 q7, q7 \n" - "vcvt.s32.f32 q8, q8 \n" - "vmovn.s32 d2, q1 \n" - "vmovn.s32 d3, q2 \n" - "vmovn.s32 d4, q3 \n" - "vmovn.s32 d5, q4 \n" - "vmovn.s32 d6, q5 \n" - "vmovn.s32 d7, q6 \n" - "vmovn.s32 d8, q7 \n" - "vmovn.s32 d9, q8 \n" - "vmovn.s16 d18, q1 \n" - "vmovn.s16 d20, q2 \n" - "vmovn.s16 d22, q3 \n" - "vmovn.s16 d24, q4 \n" - "vld1.32 {q1, q2}, [%[x0]] \n" - "vld1.32 {q3, q4}, [%[x1]] \n" - "vld1.32 {q5, q6}, [%[x2]] \n" - "vld1.32 {q7, q8}, [%[x3]] \n" - "vmul.f32 q1, q1, q0 \n" - "vmul.f32 q2, q2, q0 \n" - "vmul.f32 q3, q3, q0 \n" - "vmul.f32 q4, q4, q0 \n" - "vmul.f32 q5, q5, q0 \n" - "vmul.f32 q6, q6, q0 \n" - "vmul.f32 q7, q7, q0 \n" - "vmul.f32 q8, q8, q0 \n" - "vcvt.s32.f32 q1, q1 \n" - "vcvt.s32.f32 q2, q2 \n" - "vcvt.s32.f32 q3, q3 \n" - "vcvt.s32.f32 q4, q4 \n" - "vcvt.s32.f32 q5, q5 \n" - "vcvt.s32.f32 q6, q6 \n" - "vcvt.s32.f32 q7, q7 \n" - "vcvt.s32.f32 q8, q8 \n" - "vmovn.s32 d2, q1 \n" - "vmovn.s32 d3, q2 \n" - "vmovn.s32 d4, q3 \n" - "vmovn.s32 d5, q4 \n" - "vmovn.s32 d6, q5 \n" - "vmovn.s32 d7, q6 \n" - "vmovn.s32 d8, q7 \n" - "vmovn.s32 d9, q8 \n" - "vmovn.s16 d19, q1 \n" - "vmovn.s16 d21, q2 \n" - "vmovn.s16 d23, q3 \n" - "vmovn.s16 d25, q4 \n" - - "cmp %[remain], #8 \n" - "blt store_4w_%= \n" - "vst1.32 {d18}, [%[y0]]! \n" - "vst1.32 {d20}, [%[y1]]! \n" - "vst1.32 {d22}, [%[y2]]! \n" - "vst1.32 {d24}, [%[y3]]! \n" - "vmov.32 d18, d19 \n" - "vmov.32 d20, d21 \n" - "vmov.32 d22, d23 \n" - "vmov.32 d24, d25 \n" - "sub %[remain], #8 \n" - - "store_4w_%=: \n" - "cmp %[remain], #4 \n" - "blt store_2w_%= \n" - "vst1.32 {d18[0]}, [%[y0]]! \n" - "vst1.32 {d20[0]}, [%[y1]]! \n" - "vst1.32 {d22[0]}, [%[y2]]! \n" - "vst1.32 {d24[0]}, [%[y3]]! \n" - "vext.32 d18, d18, d18, #1 \n" - "vext.32 d20, d20, d20, #1 \n" - "vext.32 d22, d22, d22, #1 \n" - "vext.32 d24, d24, d24, #1 \n" - "sub %[remain], #4 \n" - - "store_2w_%=: \n" - "cmp %[remain], #2 \n" - "blt store_1w_%= \n" - "vst1.16 {d18[0]}, [%[y0]]! \n" - "vst1.16 {d20[0]}, [%[y1]]! \n" - "vst1.16 {d22[0]}, [%[y2]]! \n" - "vst1.16 {d24[0]}, [%[y3]]! \n" - "vext.16 d18, d18, d18, #1 \n" - "vext.16 d20, d20, d20, #1 \n" - "vext.16 d22, d22, d22, #1 \n" - "vext.16 d24, d24, d24, #1 \n" - "sub %[remain], #2 \n" - - "store_1w_%=:" - "cmp %[remain], #1 \n" - "blt end_%= \n" - "vst1.8 {d18[0]}, [%[y0]]! \n" - "vst1.8 {d20[0]}, [%[y1]]! \n" - "vst1.8 {d22[0]}, [%[y2]]! \n" - "vst1.8 {d24[0]}, [%[y3]]! \n" - - "end_%=: \n" - : [x0] "+r"(x0), [x1] "+r"(x1), [x2] "+r"(x2), [x3] "+r"(x3), - [y0] "+r"(y0), [y1] "+r"(y1), [y2] "+r"(y2), [y3] "+r"(y3), - [loop] "+r"(loop), [remain] "+r"(remain) - : [scale] "r"(scale) - : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", - "q8", "q9", "q10", "q11", "q12"); - asm volatile( - "vdup.s8 d0, %[val] \n" - "cmp %[pad_loop], #0 \n" - "ble store_pad_2w_%= \n" - "loop_pad_4w_%=: \n" - "vst1.32 {d0[0]}, [%[y0]]! \n" - "vst1.32 {d0[0]}, [%[y1]]! \n" - "vst1.32 {d0[0]}, [%[y2]]! \n" - "vst1.32 {d0[0]}, [%[y3]]! \n" - "subs %[pad_loop], #1 \n" - "bne loop_pad_4w_%= \n" - - "store_pad_2w_%=: \n" - "cmp %[pad_remain], #2 \n" - "blt store_pad_1w_%= \n" - "vst1.16 {d0[0]}, [%[y0]]! 
\n" - "vst1.16 {d0[0]}, [%[y1]]! \n" - "vst1.16 {d0[0]}, [%[y2]]! \n" - "vst1.16 {d0[0]}, [%[y3]]! \n" - "sub %[pad_remain], #2 \n" - - "store_pad_1w_%=: \n" - "cmp %[pad_remain], #1 \n" - "blt end_%= \n" - "vst1.8 {d0[0]}, [%[y0]]! \n" - "vst1.8 {d0[0]}, [%[y1]]! \n" - "vst1.8 {d0[0]}, [%[y2]]! \n" - "vst1.8 {d0[0]}, [%[y3]]! \n" - "end_%=: \n" - : [y0] "+r"(y0), [y1] "+r"(y1), [y2] "+r"(y2), [y3] "+r"(y3), - [pad_loop] "+r"(pad_loop), [pad_remain] "+r"(pad_remain) - : [val] "r"(padding_val) - : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", - "q8", "q9", "q10", "q11", "q12"); - } - } - for (int c = (channels & 0xFFFC); c < channels; ++c) { - const float *input0 = x + (batch * channels + c) * input_spatial_size; - size_t offset = (batch * channels + c) * output_spatial_size; - for (int h = 0; h < 2; ++h) { - int8_t *y0 = - y + offset + h * ((input_h + paddings[0]) * output_w - paddings[1]); - int loop = start >> 4; - int remain = start & 0xF; - asm volatile( - "vdup.s8 q0, %[val] \n" - "cmp %[loop], #0 \n" - "ble start_remain_%= \n" - - "store_16w_%=: \n" - "vst1.32 {q0}, [%[y0]]! \n" - "subs %[loop], #1 \n" - "bne store_16w_%= \n" - - "start_remain_%=: \n" - "cmp %[remain], #8 \n" - "blt store_4w_%= \n" - "vst1.32 {d0}, [%[y0]]! \n" - "sub %[remain], #8 \n" - - "store_4w_%=: \n" - "cmp %[remain], #4 \n" - "blt store_2w_%= \n" - "vst1.32 {d0[0]}, [%[y0]]! \n" - "sub %[remain], #4 \n" - - "store_2w_%=: \n" - "cmp %[remain], #4 \n" - "blt store_1w_%= \n" - "vst1.16 {d0[0]}, [%[y0]]! \n" - "sub %[remain], #2 \n" - - "store_1w_%=: \n" - "cmp %[remain], #1 \n" - "blt end_%= \n" - "vst1.8 {d0[0]}, [%[y0]]! \n" - "end_%=: \n" - : [y0] "+r"(y0), [loop] "+r"(loop), [remain] "+r"(remain) - : [val] "r"(padding_val) - : "cc", "memory", "q0"); - } - // quantize valid area - int8_t *y0 = y + offset + start; - for (int h = 0; h < input_h; ++h) { - const float *x0 = input0 + h * input_w; - int loop = input_w >> 4; - int remain = input_w & 0xF; - int pad_loop = paddings[1] >> 1; // (paddings[1] << 1) >> 2 - int pad_remain = (paddings[1] << 1) & 0x3; - asm volatile( - "vdup.f32 q0, %[scale] \n" - "cmp %[loop], #0 \n" - "ble quantize_remain_%= \n" - - "loop_quantize_%=: \n" - "vld1.32 {q1, q2}, [%[x0]]! \n" - "vmul.f32 q1, q1, q0 \n" - "vmul.f32 q2, q2, q0 \n" - "vcvt.s32.f32 q1, q1 \n" - "vcvt.s32.f32 q2, q2 \n" - "vmovn.s32 d2, q1 \n" - "vmovn.s32 d3, q2 \n" - "vmovn.s16 d18, q1 \n" - "vld1.32 {q1, q2}, [%[x0]]! \n" - "vmul.f32 q1, q1, q0 \n" - "vmul.f32 q2, q2, q0 \n" - "vcvt.s32.f32 q1, q1 \n" - "vcvt.s32.f32 q2, q2 \n" - "vmovn.s32 d2, q1 \n" - "vmovn.s32 d3, q2 \n" - "vmovn.s16 d19, q1 \n" - "vst1.32 {q9}, [%[y0]]! \n" - - "subs %[loop], #1 \n" - "bne loop_quantize_%= \n" - - "quantize_remain_%=: \n" - "cmp %[remain], #0 \n" - "ble start_pad_%= \n" - - "vldm %[x0], {d2-d9} \n" - "vmul.f32 q1, q1, q0 \n" - "vmul.f32 q2, q2, q0 \n" - "vcvt.s32.f32 q1, q1 \n" - "vcvt.s32.f32 q2, q2 \n" - "vmovn.s32 d2, q1 \n" - "vmovn.s32 d3, q2 \n" - "vmovn.s16 d18, q1 \n" - "vmul.f32 q3, q3, q0 \n" - "vmul.f32 q4, q4, q0 \n" - "vcvt.s32.f32 q1, q3 \n" - "vcvt.s32.f32 q2, q4 \n" - "vmovn.s32 d2, q1 \n" - "vmovn.s32 d3, q2 \n" - "vmovn.s16 d19, q1 \n" - - "cmp %[remain], #8 \n" - "blt store_4w_%= \n" - "vst1.32 {d18}, [%[y0]]! \n" - "vmov.32 d18, d19 \n" - "sub %[remain], #8 \n" - - "store_4w_%=: \n" - "cmp %[remain], #4 \n" - "blt store_2w_%= \n" - "vst1.32 {d18[0]}, [%[y0]]! 
\n" - "vext.32 d18, d18, d18, #1 \n" - "sub %[remain], #4 \n" - - "store_2w_%=: \n" - "cmp %[remain], #2 \n" - "blt store_1w_%= \n" - "vst1.16 {d18[0]}, [%[y0]]! \n" - "vext.16 d18, d18, d18, #1 \n" - "sub %[remain], #2 \n" - - "store_1w_%=:" - "cmp %[remain], #1 \n" - "blt start_pad_%= \n" - "vst1.8 {d18[0]}, [%[y0]]! \n" - - "start_pad_%=: \n" - "vdup.s8 d0, %[val] \n" - "cmp %[pad_loop], #0 \n" - "ble pad_remain_%= \n" - "loop_pad_4w_%=: \n" - "vst1.32 {d0[0]}, [%[y0]]! \n" - "subs %[pad_loop], #1 \n" - "bne loop_pad_4w_%= \n" - - "pad_remain_%=: \n" - "cmp %[pad_remain], #2 \n" - "blt store_pad_1w_%= \n" - "vst1.16 {d0[0]}, [%[y0]]! \n" - "sub %[pad_remain], #2 \n" - - "store_pad_1w_%=: \n" - "cmp %[pad_remain], #1 \n" - "blt end_%= \n" - "vst1.8 {d0[0]}, [%[y0]]! \n" - "end_%=: \n" - : [x0] "+r"(x0), [y0] "+r"(y0), [loop] "+r"(loop), - [remain] "+r"(remain), [pad_loop] "+r"(pad_loop), - [pad_remain] "+r"(pad_remain) - : [scale] "r"(scale), [val] "r"(padding_val) - : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q9"); - } - } + for (size_t i = 0; i < remain; ++i) { + max_abs = std::max(max_abs, std::abs(x[i])); } + return max_abs; } -#endif // __aarch64__ -#endif // ARM_NEON template <> bool QuantizeKernel::Init(QuantizeParam *param) { @@ -795,19 +191,15 @@ void QuantizeKernel::Compute(const QuantizeParam ¶m) { // only support int8 currently float scale = 127 / max_abs; param.online_scale_->mutable_data()[0] = max_abs; - const auto &paddings = param.paddings_; - // std::vector paddings = {0, 0}; - // const auto padding_val = param.padding_val_; - int8_t padding_val = 0; switch (param.round_type_) { case ROUND_NEAREST_TO_EVEN: - quantize_round_to_even(input, scale, paddings, padding_val, output); + Quantize(input, scale, output); break; case ROUND_NEAREST_TOWARDS_ZERO: - quantize_round_to_zero(input, scale, paddings, padding_val, output); + Quantize(input, scale, output); break; case ROUND_NEAREST_AWAY_ZERO: - quantize_round_to_nearest(input, scale, paddings, padding_val, output); + Quantize(input, scale, output); break; default: LOG(kLOG_ERROR) << "round type is not supported."; diff --git a/src/operators/kernel/central-arm-func/conv_add_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_arm_func.h index 3b5924ecbf886159d129212cc36c8630cb8cce2f..988f0b0f03b84c25a2e17e9d14054f99dcce4916 100644 --- a/src/operators/kernel/central-arm-func/conv_add_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_add_arm_func.h @@ -132,10 +132,10 @@ void ConvAddCompute(const FusionConvAddParam ¶m) { // param.Output(), false); if (param.Paddings()[0] == 0) { math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(), - *param.Bias(), true); + param.Bias(), true); } else { math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(), - param.Output(), *param.Bias(), true); + param.Output(), param.Bias(), true); } } else { ConvAddBasic(param); diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.h b/src/operators/kernel/central-arm-func/conv_arm_func.h index b01a654c713f2328d62714f23af68d606380d203..7f40157c30ad19472045eb53bd7a99e577429db5 100644 --- a/src/operators/kernel/central-arm-func/conv_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_arm_func.h @@ -164,31 +164,21 @@ template inline void DepthwiseConv3x3(const ConvParam ¶m) { const Tensor *input = param.Input(); const Tensor *filter = param.Filter(); + const std::vector &paddings = param.Paddings(); + const std::vector &strides = param.Strides(); + const int batch_size = input->dims()[0]; Tensor 
*output = param.Output(); output->mutable_data(); - const std::vector &paddings = param.Paddings(); - const std::vector &strides = param.Strides(); - const int batch_size = static_cast(input->dims()[0]); - Tensor input_pad; - math::PadFunctor pad; for (int i = 0; i < batch_size; i++) { Tensor in_batch = input->Slice(i, i + 1); Tensor out_batch = output->Slice(i, i + 1); - if (paddings[0] || paddings[1]) { - framework::DDim pad_shape = in_batch.dims(); - pad_shape[2] += 2 * paddings[0]; - pad_shape[3] += 2 * paddings[1]; - input_pad.mutable_data(pad_shape); - pad(in_batch, paddings[0], paddings[0], paddings[1], paddings[1], - &input_pad); - } else { - input_pad = in_batch; - } if (strides[0] == 1) { - math::DepthwiseConv3x3s1(input_pad, *filter, &out_batch); + math::DepthwiseConv3x3S1(in_batch, *filter, paddings, + &out_batch); } else if (strides[0] == 2) { - math::DepthwiseConv3x3s2(input_pad, *filter, &out_batch); + math::DepthwiseConv3x3S2(in_batch, *filter, paddings, + &out_batch); } else { // math::DepthwiseConv3x3(input_pad, *filter, // &out_batch); diff --git a/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h b/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h deleted file mode 100644 index b48b03491bab9594f36cad0b21485ae72c8c3c31..0000000000000000000000000000000000000000 --- a/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h +++ /dev/null @@ -1,53 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#ifdef DEPTHWISECONV_OP - -#pragma once -#include -#include "operators/kernel/central-arm-func/conv_arm_func.h" -#include "operators/math/depthwise_conv3x3.h" -#include "operators/op_param.h" - -namespace paddle_mobile { -namespace operators { - -template -void DepthwiseConvCompute(const ConvParam ¶m) { - Tensor Bias; - Bias.mutable_data({param.Groups()}); - if (param.Groups() == param.Input()->dims()[1] && - param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { - math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), - &Bias, false); - } else if (param.Groups() == param.Input()->dims()[1] && - param.Input()->dims()[1] == param.Output()->dims()[1] && - param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { - // math::DepthwiseConv3x3(param.Input(), param.Strides(), - // param.Paddings(), - // param.Filter(), &Bias, param.Output(), false); - math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(), param.Output(), - Bias, false); - - } else { - GemmConv(param); - } -} - -} // namespace operators -} // namespace paddle_mobile - -#endif diff --git a/src/operators/kernel/dequant_add_bn_relu_kernel.h b/src/operators/kernel/dequant_add_bn_kernel.h similarity index 75% rename from src/operators/kernel/dequant_add_bn_relu_kernel.h rename to src/operators/kernel/dequant_add_bn_kernel.h index 7138e5c415caca6766913f9959bd41def0943d34..2fcdad6903e378121c265080f68c35c451714e30 100644 --- a/src/operators/kernel/dequant_add_bn_relu_kernel.h +++ b/src/operators/kernel/dequant_add_bn_kernel.h @@ -14,7 +14,7 @@ limitations under the License. */ #pragma once -#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP +#ifdef FUSION_DEQUANT_ADD_BN_OP #include "framework/operator.h" #include "operators/op_param.h" @@ -23,12 +23,12 @@ namespace paddle_mobile { namespace operators { template -class FusionDequantAddBNReluKernel +class FusionDequantAddBNKernel : public framework::OpKernelBase> { + FusionDequantAddBNParam> { public: - void Compute(const FusionDequantAddBNReluParam ¶m); - bool Init(FusionDequantAddBNReluParam *param); + void Compute(const FusionDequantAddBNParam ¶m); + bool Init(FusionDequantAddBNParam *param); }; } // namespace operators diff --git a/src/operators/kernel/dequant_bn_relu_kernel.h b/src/operators/kernel/dequant_bn_relu_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..edea449dd68db474b14b02304bbdf63768e1bfb0 --- /dev/null +++ b/src/operators/kernel/dequant_bn_relu_kernel.h @@ -0,0 +1,46 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "framework/operator.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +#ifdef FUSION_DEQUANT_BN_RELU_OP +template +class FusionDequantBNReluKernel + : public framework::OpKernelBase> { + public: + void Compute(const FusionDequantBNReluParam ¶m); + bool Init(FusionDequantBNReluParam *param); +}; +#endif + +#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP +template +class FusionDequantAddBNReluKernel + : public framework::OpKernelBase> { + public: + void Compute(const FusionDequantAddBNReluParam ¶m); + bool Init(FusionDequantAddBNReluParam *param); +}; +#endif + +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/math/depthwise_conv3x3.cpp b/src/operators/math/depthwise_conv3x3.cpp index 39b9b8d3f1c5c2bf09a3db5de5216dd1a08b491a..a4466a52fac228812e8df205a61bdb594775d327 100644 --- a/src/operators/math/depthwise_conv3x3.cpp +++ b/src/operators/math/depthwise_conv3x3.cpp @@ -1272,13 +1272,16 @@ void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input, void DepthwiseConv3x3s2p1v2(const framework::Tensor *input, const framework::Tensor *filter, - framework::Tensor *output, framework::Tensor bias, + framework::Tensor *output, framework::Tensor *bias, bool if_bias) { #if __ARM_NEON const float *input_data = input->data(); const float *filter_data = filter->data(); float *output_data = output->data(); - const float *bias_data = bias.data(); + const float *bias_data; + if (if_bias) { + bias_data = bias->data(); + } const int in_h = static_cast(input->dims()[2]); const int in_w = static_cast(input->dims()[3]); @@ -1905,7 +1908,7 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input, void DepthwiseConv3x3s2p0(const framework::Tensor *input, const framework::Tensor *filter, - framework::Tensor *output, framework::Tensor bias, + framework::Tensor *output, framework::Tensor *bias, bool if_bias) { #if __ARM_NEON @@ -1925,7 +1928,7 @@ void DepthwiseConv3x3s2p0(const framework::Tensor *input, for (int c = 0; c < input_channel; c++) { const float *filter_data = filter->data() + c * 9; const float *input_data = input->data() + c * inhxw; - const float *bias_data = bias.data() + c; + const float *bias_data = bias->data() + c; float *output_data = output->data() + c * outhxw; float w00 = filter_data[0]; float w01 = filter_data[1]; diff --git a/src/operators/math/depthwise_conv3x3.h b/src/operators/math/depthwise_conv3x3.h index 72cadaf21553a428e1479d5548d2aa5f4fcdf90c..ca8f45fa5186fc1a2642a53f27526c6898bfb8e3 100644 --- a/src/operators/math/depthwise_conv3x3.h +++ b/src/operators/math/depthwise_conv3x3.h @@ -50,7 +50,7 @@ void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input, void DepthwiseConv3x3s2p1v2(const framework::Tensor *input, const framework::Tensor *filter, - framework::Tensor *output, framework::Tensor bias, + framework::Tensor *output, framework::Tensor *bias, bool if_bias); void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input, @@ -62,7 +62,7 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input, void DepthwiseConv3x3s2p0(const framework::Tensor *input, const framework::Tensor *filter, - framework::Tensor *output, framework::Tensor bias, + framework::Tensor *output, framework::Tensor *bias, bool if_bias); // TODO(hjchen2) need to be implemented @@ -70,16 +70,19 @@ void DepthwiseConv3x3s2p0(const framework::Tensor *input, // void DepthwiseConv3x3(const framework::Tensor *input, // const framework::Tensor *filter, // const std::vector &strides, 
+// const std::vector &paddings, // framework::Tensor *output); template -void DepthwiseConv3x3s1(const framework::Tensor &input, +void DepthwiseConv3x3S1(const framework::Tensor &input, const framework::Tensor &filter, + const std::vector &paddings, framework::Tensor *output); template -void DepthwiseConv3x3s2(const framework::Tensor &input, +void DepthwiseConv3x3S2(const framework::Tensor &input, const framework::Tensor &filter, + const std::vector &paddings, framework::Tensor *output); } // namespace math diff --git a/src/operators/math/depthwise_conv3x3_int8.cpp b/src/operators/math/depthwise_conv3x3_int8.cpp index ddd8f79f7ce350e048585917f96d82639d4ea951..9b4c6096ecdbd7adee27728ebaae47149392dad9 100644 --- a/src/operators/math/depthwise_conv3x3_int8.cpp +++ b/src/operators/math/depthwise_conv3x3_int8.cpp @@ -12,12 +12,300 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#if defined(__ARM_NEON__) && !defined(__aarch64__) + #include "operators/math/depthwise_conv3x3.h" +#ifdef __ARM_NEON__ +#include +#endif namespace paddle_mobile { namespace operators { namespace math { +template +inline void Depth3x3ValidColLoadInput(const int8_t *input, const int input_w, + const int valid_cols, int16x8_t *y0, + int16x8_t *y1, int16x8_t *y2) { + PADDLE_MOBILE_THROW_EXCEPTION("Stride %d is not supported.", Stride); +} + +template <> +inline void Depth3x3ValidColLoadInput<1>(const int8_t *input, const int input_w, + const int valid_cols, int16x8_t *y0, + int16x8_t *y1, int16x8_t *y2) { + int8_t fake_input[3][8]; + if (valid_cols == 1) { + for (int i = 0; i < 8; ++i, input += input_w) { + fake_input[0][i] = input[0]; + } + } else if (valid_cols == 2) { + for (int i = 0; i < 8; ++i, input += input_w) { + fake_input[0][i] = input[0]; + fake_input[1][i] = input[1]; + } + } else { + for (int i = 0; i < 8; ++i, input += input_w) { + fake_input[0][i] = input[0]; + fake_input[1][i] = input[1]; + fake_input[2][i] = input[2]; + } + } + int8x8_t input0 = vld1_s8(fake_input[0]); + int8x8_t input1 = vld1_s8(fake_input[1]); + int8x8_t input2 = vld1_s8(fake_input[2]); + y0[0] = vmovl_s8(input0); + y1[0] = vmovl_s8(input1); + y2[0] = vmovl_s8(input2); + y0[1] = vextq_s16(y0[0], y0[0], 1); + y0[2] = vextq_s16(y0[0], y0[0], 2); + y1[1] = vextq_s16(y1[0], y1[0], 1); + y1[2] = vextq_s16(y1[0], y1[0], 2); + y2[1] = vextq_s16(y2[0], y2[0], 1); + y2[2] = vextq_s16(y2[0], y2[0], 2); +} + +template <> +inline void Depth3x3ValidColLoadInput<2>(const int8_t *input, const int input_w, + const int valid_cols, int16x8_t *y0, + int16x8_t *y1, int16x8_t *y2) { + int8_t fake_input[3][13]; + if (valid_cols == 1) { + for (int i = 0; i < 13; ++i, input += input_w) { + fake_input[0][i] = input[0]; + } + } else if (valid_cols == 2) { + for (int i = 0; i < 13; ++i, input += input_w) { + fake_input[0][i] = input[0]; + fake_input[1][i] = input[1]; + } + } else { + for (int i = 0; i < 13; ++i, input += input_w) { + fake_input[0][i] = input[0]; + fake_input[1][i] = input[1]; + fake_input[2][i] = input[2]; + } + } + int8x8x2_t input0 = vld2_s8(fake_input[0]); + int8x8x2_t input1 = vld2_s8(fake_input[1]); + int8x8x2_t input2 = vld2_s8(fake_input[2]); + y0[0] = vmovl_s8(input0.val[0]); + y0[1] = vmovl_s8(input0.val[1]); + y0[2] = vextq_s16(y0[0], y0[0], 1); + y1[0] = vmovl_s8(input1.val[0]); + y1[1] = vmovl_s8(input1.val[1]); + y1[2] = vextq_s16(y1[0], y1[0], 1); + y2[0] = vmovl_s8(input2.val[0]); + y2[1] = 
vmovl_s8(input2.val[1]); + y2[2] = vextq_s16(y2[0], y2[0], 1); +} + +template +inline void DepthwiseConv3x3ValidCol(const int8_t *input, const int8_t *filter, + const int h_output, const int h_output_end, + const int w_output, const int input_h, + const int input_w, const int padding_h, + const int padding_w, const int output_w, + int32_t *output) { + const int w_in_start = -padding_w + w_output * Stride_w; + const int w_in_end = w_in_start + 3; + const int w_start = w_in_start > 0 ? w_in_start : 0; + const int w_end = w_in_end < input_w ? w_in_end : input_w; + int remain_start = h_output; + +#ifdef __ARM_NEON__ + int output_tiles = (h_output_end - h_output) / 6; + remain_start = h_output + output_tiles * 6; + int input_h_start = h_output * Stride_h - padding_h; + size_t input_offset = input_h_start * input_w + w_start; + size_t output_offset = h_output * output_w + w_output; + int16x8_t _input[3][3]; + int16x4_t _kernel[3]; + int32x4_t _sum0, _sum1; + const int8_t *filter_ptr = filter; + asm volatile( + "mov r0, #3 \n" + "vld1.s8 d10, [%[filter]], r0 \n" + "vld1.s8 d11, [%[filter]], r0 \n" + "vld1.s8 d12, [%[filter]] \n" + "vtrn.8 d10, d11 \n" + "vtrn.8 d12, d13 \n" + "vtrn.16 d10, d12 \n" + "vtrn.16 d11, d13 \n" + "vmovl.s8 q7, d10 \n" + "vmovl.s8 q8, d11 \n" + "vmovl.s8 q9, d12 \n" + "vmov.32 %[_kernel0], d14 \n" + "vmov.32 %[_kernel1], d16 \n" + "vmov.32 %[_kernel2], d18 \n" + : [_kernel0] "+w"(_kernel[0]), [_kernel1] "+w"(_kernel[1]), + [_kernel2] "+w"(_kernel[2]) + : [filter] "r"(filter_ptr) + : "memory", "q5", "q6", "q7", "q8", "q9", "r0"); + int valid_cols = w_end - w_start; + for (int h = 0; h < output_tiles * 6; h += 6) { + int32_t *output0 = output + output_offset; + int32_t *output1 = output0 + output_w; + int32_t *output2 = output1 + output_w; + int32_t *output3 = output2 + output_w; + int32_t *output4 = output3 + output_w; + int32_t *output5 = output4 + output_w; + Depth3x3ValidColLoadInput(input + input_offset, input_w, + valid_cols, _input[0], _input[1], + _input[2]); + _sum0 = veorq_s32(_sum0, _sum0); + _sum1 = veorq_s32(_sum1, _sum1); + for (int w_in = 0; w_in < valid_cols; ++w_in) { + int index = w_in + w_start - w_in_start; + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_input[w_in][0]), + _kernel[index], 0); + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_input[w_in][1]), + _kernel[index], 1); + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_input[w_in][2]), + _kernel[index], 2); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_input[w_in][0]), + _kernel[index], 0); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_input[w_in][1]), + _kernel[index], 1); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_input[w_in][2]), + _kernel[index], 2); + } + vst1q_lane_s32(output0, _sum0, 0); + vst1q_lane_s32(output1, _sum0, 1); + vst1q_lane_s32(output2, _sum0, 2); + vst1q_lane_s32(output3, _sum0, 3); + vst1q_lane_s32(output4, _sum1, 0); + vst1q_lane_s32(output5, _sum1, 1); + input_offset += 6 * Stride_h * input_w; + output_offset += 6 * output_w; + } +#endif + for (int h = remain_start; h < h_output_end; ++h) { + int32_t value = 0; + const int h_in_start = -padding_h + h * Stride_h; + for (int i = 0; i < 3; ++i) { + for (int w_in = w_start; w_in < w_end; ++w_in) { + value += filter[i * 3 + (w_in - w_in_start)] * + input[(h_in_start + i) * input_w + w_in]; + } + } + output[h * output_w + w_output] = value; + } +} + +#define DEPTHWISE_CONV_NORMAL_BORDER(start, end) \ + for (int w = start; w < end; ++w) { \ + const int w_in_start = -padding_w + w * Stride_w; \ + const int w_in_end = w_in_start + 
3; \ + const int w_start = w_in_start > 0 ? w_in_start : 0; \ + const int w_end = w_in_end < input_w ? w_in_end : input_w; \ + int32_t value = 0; \ + for (int h_in = h_start; h_in < h_end; ++h_in) { \ + for (int w_in = w_start; w_in < w_end; ++w_in) { \ + value += filter[(h_in - h_in_start) * 3 + (w_in - w_in_start)] * \ + input[h_in * input_w + w_in]; \ + } \ + } \ + output_ptr[w] = value; \ + } + +template +inline void Depth3x3NormalRowLoadInput(const int8_t *input, + int16x8_t &y0, // NOLINT + int16x8_t &y1, // NOLINT + int16x8_t &y2) { // NOLINT + PADDLE_MOBILE_THROW_EXCEPTION("Stride %d is not supported.", Stride); +} + +template <> +inline void Depth3x3NormalRowLoadInput<1>(const int8_t *input, + int16x8_t &y0, // NOLINT + int16x8_t &y1, // NOLINT + int16x8_t &y2) { // NOLINT + int8x8_t x0 = vld1_s8(input); + y0 = vmovl_s8(x0); + y1 = vextq_s16(y0, y0, 1); + y2 = vextq_s16(y1, y1, 1); +} + +template <> +inline void Depth3x3NormalRowLoadInput<2>(const int8_t *input, + int16x8_t &y0, // NOLINT + int16x8_t &y1, // NOLINT + int16x8_t &y2) { // NOLINT + int8x8x2_t x0 = vld2_s8(input); + y0 = vmovl_s8(x0.val[0]); + y1 = vmovl_s8(x0.val[1]); + y2 = vextq_s16(y0, y0, 1); +} + +template +inline void DepthwiseConv3x3NormalRow(const int8_t *input, const int8_t *filter, + const int h_output, const int input_h, + const int input_w, const int padding_h, + const int padding_w, const int output_w, + int32_t *output) { + const int h_in_start = -padding_h + h_output * Stride_h; + const int h_in_end = h_in_start + 3; + const int h_start = h_in_start > 0 ? h_in_start : 0; + const int h_end = h_in_end < input_h ? h_in_end : input_h; + + int valid_w_start = (padding_w + Stride_w - 1) / Stride_w; + int valid_w_end = output_w - valid_w_start; + + int32_t *output_ptr = output + h_output * output_w; + // border left + DEPTHWISE_CONV_NORMAL_BORDER(0, valid_w_start) + // middle + int remain_start = valid_w_start; +#ifdef __ARM_NEON__ + int output_tiles = (valid_w_end - valid_w_start) / 6; + remain_start = valid_w_start + output_tiles * 6; + int32x4_t _sum0, _sum1; + int16x8_t y0, y1, y2; + int16x4_t _kernel[3]; + for (int h_in = h_start; h_in < h_end; ++h_in) { + int index = h_in - h_in_start; + int8x8_t w0 = vld1_s8(filter + index * 3); + int16x8_t w1 = vmovl_s8(w0); + _kernel[index] = vget_low_s16(w1); + } + for (int w = 0; w < output_tiles * 6; w += 6) { + _sum0 = veorq_s32(_sum0, _sum0); + _sum1 = veorq_s32(_sum1, _sum1); + int output_offset = valid_w_start + w; + int input_w_offset = output_offset * Stride_w - padding_w; + for (int h_in = h_start; h_in < h_end; ++h_in) { + int index = h_in - h_in_start; + Depth3x3NormalRowLoadInput( + input + h_in * input_w + input_w_offset, y0, y1, y2); + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(y0), _kernel[index], 0); + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(y1), _kernel[index], 1); + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(y2), _kernel[index], 2); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(y0), _kernel[index], 0); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(y1), _kernel[index], 1); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(y2), _kernel[index], 2); + } + vst1q_s32(output_ptr + output_offset, _sum0); + vst1q_lane_s32(output_ptr + output_offset + 4, _sum1, 0); + vst1q_lane_s32(output_ptr + output_offset + 5, _sum1, 1); + } +#endif + for (int w = remain_start; w < valid_w_end; ++w) { + int32_t value = 0; + int input_start = -padding_w + w * Stride_w; + for (int h_in = h_start; h_in < h_end; ++h_in) { + for (int j = 0; j < 3; ++j) { + value += filter[(h_in 
- h_in_start) * 3 + j] * + input[h_in * input_w + j + input_start]; + } + } + output_ptr[w] = value; + } + // border right + DEPTHWISE_CONV_NORMAL_BORDER(valid_w_end, output_w) +} + // template<> // void DepthwiseConv3x3( // const framework::Tensor *input, const framework::Tensor *filter, @@ -27,43 +315,72 @@ namespace math { // } template <> -void DepthwiseConv3x3s1(const framework::Tensor &input, +void DepthwiseConv3x3S1(const framework::Tensor &input, const framework::Tensor &filter, + const std::vector &paddings, framework::Tensor *output) { const int8_t *input_data = input.data(); const int8_t *filter_data = filter.data(); int32_t *out_data = output->mutable_data(); - // make sure that batch size is 1 - int input_c = input.dims()[1]; int input_h = input.dims()[2]; int input_w = input.dims()[3]; - int output_c = output->dims()[1]; int output_h = output->dims()[2]; int output_w = output->dims()[3]; + int padding_h = paddings[0]; + int padding_w = paddings[1]; int image_size = input_h * input_w; int out_image_size = output_h * output_w; -#if __aarch64__ - // TODO(hjchen2) -#else + int valid_h_start = padding_h; + int valid_h_end = output_h - valid_h_start; + int valid_h = valid_h_end - valid_h_start; + int valid_w_start = padding_w; + int valid_w_end = output_w - valid_w_start; + int valid_w = valid_w_end - valid_w_start; + #pragma omp parallel for - for (int g = 0; g < input_c; ++g) { - const int8_t* input_ptr = input_data + g * image_size; - const int8_t* filter_ptr = filter_data + g * 9; - int32_t* output_ptr = out_data + g * out_image_size; - int loops = (input_w - 2) / 6; - int remain = input_w - 2 - loops * 6; - for (int h = 0; h < input_h - 5 /*(input_h - 2) - 3*/; h += 4) { - const int8_t* input_ptr0 = input_ptr + h * input_w; - const int8_t* input_ptr1 = input_ptr0 + input_w; - const int8_t* input_ptr2 = input_ptr1 + input_w; - const int8_t* input_ptr3 = input_ptr2 + input_w; - const int8_t* input_ptr4 = input_ptr3 + input_w; - const int8_t* input_ptr5 = input_ptr4 + input_w; - int32_t* output_ptr0 = output_ptr + h * output_w; - int32_t* output_ptr1 = output_ptr0 + output_w; - int32_t* output_ptr2 = output_ptr1 + output_w; - int32_t* output_ptr3 = output_ptr2 + output_w; - int loop = loops; + for (int g = 0; g < input.dims()[1]; ++g) { + const int8_t *input_ptr = input_data + g * image_size; + const int8_t *filter_ptr = filter_data + g * 9; + int32_t *output_ptr = out_data + g * out_image_size; + // top + for (int h = 0; h < valid_h_start; ++h) { + DepthwiseConv3x3NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h, + input_w, padding_h, padding_w, output_w, + output_ptr); + } + // left + for (int w = 0; w < valid_w_start; ++w) { + DepthwiseConv3x3ValidCol<1, 1>( + input_ptr, filter_ptr, valid_h_start, valid_h_end, w, input_h, + input_w, padding_h, padding_w, output_w, output_ptr); + } + // right + for (int w = valid_w_end; w < output_w; ++w) { + DepthwiseConv3x3ValidCol<1, 1>( + input_ptr, filter_ptr, valid_h_start, valid_h_end, w, input_h, + input_w, padding_h, padding_w, output_w, output_ptr); + } + // bottom + for (int h = valid_h_end; h < output_h; ++h) { + DepthwiseConv3x3NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h, + input_w, padding_h, padding_w, output_w, + output_ptr); + } + // valid + int output_w_tiles = valid_w / 6; + int output_w_remain = valid_w - output_w_tiles * 6; + for (int h = valid_h_start; h < valid_h_end - 3; h += 4) { + const int8_t *input_ptr0 = input_ptr + (h - padding_h) * input_w; + const int8_t *input_ptr1 = input_ptr0 + input_w; + const int8_t 
*input_ptr2 = input_ptr1 + input_w; + const int8_t *input_ptr3 = input_ptr2 + input_w; + const int8_t *input_ptr4 = input_ptr3 + input_w; + const int8_t *input_ptr5 = input_ptr4 + input_w; + int32_t *output_ptr0 = output_ptr + h * output_w + valid_w_start; + int32_t *output_ptr1 = output_ptr0 + output_w; + int32_t *output_ptr2 = output_ptr1 + output_w; + int32_t *output_ptr3 = output_ptr2 + output_w; + int loop = output_w_tiles; asm volatile( "vld1.32 {q0}, [%[filter_ptr]] \n" "vmovl.s8 q14, d0 \n" @@ -377,27 +694,27 @@ void DepthwiseConv3x3s1(const framework::Tensor &input, "vst1.32 {d24[0]}, [%[output_ptr1]]! \n" "vst1.32 {d28[0]}, [%[output_ptr2]]! \n" "vst1.32 {d10[0]}, [%[output_ptr3]]! \n" - "end_%=: \n" + "end_%=: \n" : [output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1), [output_ptr2] "+r"(output_ptr2), [output_ptr3] "+r"(output_ptr3), [input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1), [input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3), [input_ptr4] "+r"(input_ptr4), [input_ptr5] "+r"(input_ptr5), [loop] "+r"(loop) - : [remain] "r"(remain) + : [remain] "r"(output_w_remain) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0"); } // remain height - int start_h = (input_h - 2) & 0xFFFC; - for (int h = start_h; h < input_h - 3 /*(input_h - 2) - 1*/; h += 2) { - const int8_t* input_ptr0 = input_ptr + h * input_w; - const int8_t* input_ptr1 = input_ptr0 + input_w; - const int8_t* input_ptr2 = input_ptr1 + input_w; - const int8_t* input_ptr3 = input_ptr2 + input_w; - int32_t* output_ptr0 = output_ptr + h * output_w; - int32_t* output_ptr1 = output_ptr0 + output_w; - int loop = loops; + int start_h = valid_h_start + (valid_h & 0xFFFC); + for (int h = start_h; h < valid_h_end - 1; h += 2) { + const int8_t *input_ptr0 = input_ptr + (h - padding_h) * input_w; + const int8_t *input_ptr1 = input_ptr0 + input_w; + const int8_t *input_ptr2 = input_ptr1 + input_w; + const int8_t *input_ptr3 = input_ptr2 + input_w; + int32_t *output_ptr0 = output_ptr + h * output_w + valid_w_start; + int32_t *output_ptr1 = output_ptr0 + output_w; + int loop = output_w_tiles; asm volatile( "vld1.32 {q0}, [%[filter_ptr]] \n" "vmovl.s8 q14, d0 \n" @@ -415,9 +732,9 @@ void DepthwiseConv3x3s1(const framework::Tensor &input, : [filter_ptr] "r"(filter_ptr) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15"); asm volatile( - "mov r0, #6 \n" "cmp %[loop], #0 \n" "ble start_remain_%= \n" + "mov r0, #6 \n" // loop 6 widths "loop_2h6w_%=: \n" "vld1.32 {d9}, [%[input_ptr0]], r0 \n" @@ -589,23 +906,23 @@ void DepthwiseConv3x3s1(const framework::Tensor &input, "blt end_%= \n" "vst1.32 {d20[0]}, [%[output_ptr0]]! \n" "vst1.32 {d24[0]}, [%[output_ptr1]]! 
\n" - "end_%=: \n" + "end_%=: \n" : [output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1), [input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1), [input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3), [loop] "+r"(loop) - : [remain] "r"(remain) + : [remain] "r"(output_w_remain) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "r0"); } - start_h = (input_h - 2) & 0xFFFE; - if (start_h < input_h - 2) { - const int8_t* input_ptr0 = input_ptr + start_h * input_w; - const int8_t* input_ptr1 = input_ptr0 + input_w; - const int8_t* input_ptr2 = input_ptr1 + input_w; - int32_t* output_ptr0 = output_ptr + start_h * output_w; - int loop = loops; + start_h = valid_h_start + (valid_h & 0xFFFE); + if (start_h < valid_h_end) { + const int8_t *input_ptr0 = input_ptr + (start_h - padding_h) * input_w; + const int8_t *input_ptr1 = input_ptr0 + input_w; + const int8_t *input_ptr2 = input_ptr1 + input_w; + int32_t *output_ptr0 = output_ptr + start_h * output_w + valid_w_start; + int loop = output_w_tiles; asm volatile( "vld1.32 {q0}, [%[filter_ptr]] \n" "vmovl.s8 q14, d0 \n" @@ -623,9 +940,9 @@ void DepthwiseConv3x3s1(const framework::Tensor &input, : [filter_ptr] "r"(filter_ptr) : "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15"); asm volatile( - "mov r0, #6 \n" "cmp %[loop], #0 \n" "ble start_remain_%= \n" + "mov r0, #6 \n" // loop 6 widths "loop_1h6w_%=: \n" "vld1.32 {d9}, [%[input_ptr0]], r0 \n" @@ -736,56 +1053,91 @@ void DepthwiseConv3x3s1(const framework::Tensor &input, "cmp %[remain], #1 \n" "blt end_%= \n" "vst1.32 {d20[0]}, [%[output_ptr0]]! \n" - "end_%=: \n" + "end_%=: \n" : [output_ptr0] "+r"(output_ptr0), [input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1), [input_ptr2] "+r"(input_ptr2), [loop] "+r"(loop) - : [remain] "r"(remain) + : [remain] "r"(output_w_remain) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "r0"); } } -#endif // __aarch64__ } template <> -void DepthwiseConv3x3s2(const framework::Tensor &input, +void DepthwiseConv3x3S2(const framework::Tensor &input, const framework::Tensor &filter, + const std::vector &paddings, framework::Tensor *output) { const int8_t *input_data = input.data(); const int8_t *filter_data = filter.data(); int32_t *out_data = output->mutable_data(); - // make sure that batch size is 1 - int input_c = input.dims()[1]; int input_h = input.dims()[2]; int input_w = input.dims()[3]; - int output_c = output->dims()[1]; int output_h = output->dims()[2]; int output_w = output->dims()[3]; + int padding_h = paddings[0]; + int padding_w = paddings[1]; int image_size = input_h * input_w; int out_image_size = output_h * output_w; -#if __aarch64__ - // TODO(hjchen2) -#else + int valid_h_start = (padding_h + 1) / 2; + int valid_h_end = output_h - valid_h_start; + int valid_h = valid_h_end - valid_h_start; + int valid_w_start = (padding_w + 1) / 2; + int valid_w_end = output_w - valid_w_start; + int valid_w = valid_w_end - valid_w_start; + + // DLOG << "valid_h_start: " << valid_h_start; + // DLOG << "valid_h_end: " << valid_h_end; + // DLOG << "valid_w_start: " << valid_w_start; + // DLOG << "valid_w_end: " << valid_w_end; + #pragma omp parallel for - for (int g = 0; g < input_c; ++g) { - const int8_t* input_ptr = input_data + g * image_size; - const int8_t* filter_ptr = filter_data + g * 9; - int32_t* output_ptr = out_data + g * out_image_size; - int loops = output_w / 6; - int remain = output_w - loops * 6; - for (int h = 0; h < 
input_h - 6 /*(input_h - 1) - 5*/; h += 6) { - const int8_t* input_ptr0 = input_ptr + h * input_w; - const int8_t* input_ptr1 = input_ptr0 + input_w; - const int8_t* input_ptr2 = input_ptr1 + input_w; - const int8_t* input_ptr3 = input_ptr2 + input_w; - const int8_t* input_ptr4 = input_ptr3 + input_w; - const int8_t* input_ptr5 = input_ptr4 + input_w; - const int8_t* input_ptr6 = input_ptr5 + input_w; - int32_t* output_ptr0 = output_ptr + (h >> 1) * output_w; - int32_t* output_ptr1 = output_ptr0 + output_w; - int32_t* output_ptr2 = output_ptr1 + output_w; - int loop = loops; + for (int g = 0; g < input.dims()[1]; ++g) { + const int8_t *input_ptr = input_data + g * image_size; + const int8_t *filter_ptr = filter_data + g * 9; + int32_t *output_ptr = out_data + g * out_image_size; + // top + for (int h = 0; h < valid_h_start; ++h) { + DepthwiseConv3x3NormalRow<2, 2>(input_ptr, filter_ptr, h, input_h, + input_w, padding_h, padding_w, output_w, + output_ptr); + } + // left + for (int w = 0; w < valid_w_start; ++w) { + DepthwiseConv3x3ValidCol<2, 2>( + input_ptr, filter_ptr, valid_h_start, valid_h_end, w, input_h, + input_w, padding_h, padding_w, output_w, output_ptr); + } + // right + for (int w = valid_w_end; w < output_w; ++w) { + DepthwiseConv3x3ValidCol<2, 2>( + input_ptr, filter_ptr, valid_h_start, valid_h_end, w, input_h, + input_w, padding_h, padding_w, output_w, output_ptr); + } + // bottom + for (int h = valid_h_end; h < output_h; ++h) { + DepthwiseConv3x3NormalRow<2, 2>(input_ptr, filter_ptr, h, input_h, + input_w, padding_h, padding_w, output_w, + output_ptr); + } + // valid + int input_w_start = 2 * valid_w_start - padding_w; + int output_w_tiles = valid_w / 6; + int output_w_remain = valid_w - output_w_tiles * 6; + for (int h = valid_h_start; h < valid_h_end - 2; h += 3) { + size_t offset = (2 * h - padding_h) * input_w + input_w_start; + const int8_t *input_ptr0 = input_ptr + offset; + const int8_t *input_ptr1 = input_ptr0 + input_w; + const int8_t *input_ptr2 = input_ptr1 + input_w; + const int8_t *input_ptr3 = input_ptr2 + input_w; + const int8_t *input_ptr4 = input_ptr3 + input_w; + const int8_t *input_ptr5 = input_ptr4 + input_w; + const int8_t *input_ptr6 = input_ptr5 + input_w; + int32_t *output_ptr0 = output_ptr + h * output_w + valid_w_start; + int32_t *output_ptr1 = output_ptr0 + output_w; + int32_t *output_ptr2 = output_ptr1 + output_w; + int loop = output_w_tiles; asm volatile( "vld1.32 {q0}, [%[filter_ptr]] \n" "vmovl.s8 q14, d0 \n" @@ -803,9 +1155,9 @@ void DepthwiseConv3x3s2(const framework::Tensor &input, : [filter_ptr] "r"(filter_ptr) : "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15"); asm volatile( - "mov r0, #12 \n" "cmp %[loop], #0 \n" "ble start_remain_%= \n" + "mov r0, #12 \n" // loop 6 widths "loop_3h6w_%=: \n" "vld2.8 {d10, d11}, [%[input_ptr0]], r0 \n" @@ -1048,25 +1400,26 @@ void DepthwiseConv3x3s2(const framework::Tensor &input, "vst1.32 {d20[0]}, [%[output_ptr0]]! \n" "vst1.32 {d24[0]}, [%[output_ptr1]]! \n" "vst1.32 {d28[0]}, [%[output_ptr2]]! 
\n" - "end_%=: \n" + "end_%=: \n" : [output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1), [output_ptr2] "+r"(output_ptr2), [input_ptr6] "+r"(input_ptr6), [input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1), [input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3), [input_ptr4] "+r"(input_ptr4), [input_ptr5] "+r"(input_ptr5), [loop] "+r"(loop) - : [remain] "r"(remain) + : [remain] "r"(output_w_remain) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0"); } - int start_h = (output_h / 3) * 6; - for (int h = start_h; h < input_h - 2 /*(input_h - 1) - 1*/; h += 2) { - const int8_t* input_ptr0 = input_ptr + h * input_w; - const int8_t* input_ptr1 = input_ptr0 + input_w; - const int8_t* input_ptr2 = input_ptr1 + input_w; - int32_t* output_ptr0 = output_ptr + (h >> 1) * output_w; - int loop = loops; + int start_h = valid_h_start + valid_h / 3 * 3; + for (int h = start_h; h < valid_h_end; ++h) { + size_t offset = (2 * h - padding_h) * input_w + input_w_start; + const int8_t *input_ptr0 = input_ptr + offset; + const int8_t *input_ptr1 = input_ptr0 + input_w; + const int8_t *input_ptr2 = input_ptr1 + input_w; + int32_t *output_ptr0 = output_ptr + h * output_w + valid_w_start; + int loop = output_w_tiles; asm volatile( "vld1.32 {q0}, [%[filter_ptr]] \n" "vmovl.s8 q14, d0 \n" @@ -1084,9 +1437,9 @@ void DepthwiseConv3x3s2(const framework::Tensor &input, : [filter_ptr] "r"(filter_ptr) : "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15"); asm volatile( - "mov r0, #12 \n" "cmp %[loop], #0 \n" "ble start_remain_%= \n" + "mov r0, #12 \n" // loop 6 widths "loop_1h6w_%=: \n" "vld2.8 {d10, d11}, [%[input_ptr0]], r0 \n" @@ -1190,18 +1543,19 @@ void DepthwiseConv3x3s2(const framework::Tensor &input, "cmp %[remain], #1 \n" "blt end_%= \n" "vst1.32 {d22[0]}, [%[output_ptr0]]! 
\n" - "end_%=: \n" + "end_%=: \n" : [output_ptr0] "+r"(output_ptr0), [input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1), [input_ptr2] "+r"(input_ptr2), [loop] "+r"(loop) - : [remain] "r"(remain) + : [remain] "r"(output_w_remain) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "r0"); } } -#endif // __aarch64__ } } // namespace math } // namespace operators } // namespace paddle_mobile + +#endif diff --git a/src/operators/op_param.h b/src/operators/op_param.h index c9477c4cd1167af6bd63d74c405dafeb6a8949e7..c4f5b180b832f320ac841f593ff76076b963f55d 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -405,9 +405,9 @@ class ConvParam : public OpParam { const RType *Input() const { return input_; } - RType *&Filter() const { return filter_; } + RType *Filter() const { return filter_; } - RType *&Output() const { return output_; } + RType *Output() const { return output_; } const vector &Strides() const { return strides_; } @@ -419,6 +419,8 @@ class ConvParam : public OpParam { EXEC_INVALID = 0, EXEC_GEMM_FLOAT, EXEC_DEPTHWISE3x3S1P1_FLOAT, + EXEC_DEPTHWISE3x3S2P0_FLOAT, + EXEC_DEPTHWISE3x3S2P1_FLOAT, EXEC_DEPTHWISE3x3_FLOAT, EXEC_WINOGRAD3X3_FLOAT, EXEC_WINOGRAD5X5_FLOAT, @@ -439,8 +441,8 @@ class ConvParam : public OpParam { private: RType *input_; - mutable RType *output_; - mutable RType *filter_; + RType *output_; + RType *filter_; vector strides_; vector paddings_; vector dilations_; @@ -2573,7 +2575,9 @@ class DequantizeParam : public OpParam { DequantizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs, const AttributeMap &attrs, const Scope &scope) { input_ = InputXFrom(inputs, scope); - output_ = OutFrom(outputs, scope); + if (outputs.count("Out")) { + output_ = OutFrom(outputs, scope); + } activation_scale_ = OpParam::GetVarValue("Scale", inputs, scope); // dequantization is performed as x = x / static_scale / online_scale if (HasAttr("weight_scale", attrs)) { @@ -2593,20 +2597,19 @@ class DequantizeParam : public OpParam { }; #endif -#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP +#if defined(FUSION_DEQUANT_ADD_BN_OP) || \ + defined(FUSION_DEQUANT_ADD_BN_RELU_OP) || \ + defined(FUSION_DEQUANT_BN_RELU_OP) || defined(FUSION_DEQUANT_BN_OP) template -class FusionDequantAddBNReluParam : public DequantizeParam { +class FusionDequantBNParam : public DequantizeParam { typedef typename DtypeTensorTrait::gtype GType; typedef typename DtypeTensorTrait::rtype RType; public: - FusionDequantAddBNReluParam(const VariableNameMap &inputs, - const VariableNameMap &outputs, - const AttributeMap &attrs, const Scope &scope) + FusionDequantBNParam(const VariableNameMap &inputs, + const VariableNameMap &outputs, + const AttributeMap &attrs, const Scope &scope) : DequantizeParam(inputs, outputs, attrs, scope) { - // element wise add params - axis_ = OpParam::GetAttr("axis", attrs); - bias_ = OpParam::InputYFrom(inputs, scope); // batch norm params bn_mean_ = OpParam::GetVarValue("BNMean", inputs, scope); bn_variance_ = OpParam::GetVarValue("BNVariance", inputs, scope); @@ -2614,21 +2617,83 @@ class FusionDequantAddBNReluParam : public DequantizeParam { bn_bias_ = OpParam::GetVarValue("BNBias", inputs, scope); epsilon_ = OpParam::GetAttr("epsilon", attrs); // output - output_ = OpParam::OutFrom(outputs, scope); + if (outputs.count("Y")) { + this->output_ = OpParam::OutputYFrom(outputs, scope); + } } public: - // elementwise add - int axis_; - RType *bias_; // batch norm RType *bn_mean_; RType *bn_variance_; RType 
*bn_scale_; RType *bn_bias_; float epsilon_; - // output - RType *output_; +}; +#endif + +#if defined(FUSION_DEQUANT_ADD_BN_RELU_OP) || defined(FUSION_DEQUANT_ADD_BN_OP) +template +class FusionDequantAddBNParam : public FusionDequantBNParam { + typedef typename DtypeTensorTrait::gtype GType; + typedef typename DtypeTensorTrait::rtype RType; + + public: + FusionDequantAddBNParam(const VariableNameMap &inputs, + const VariableNameMap &outputs, + const AttributeMap &attrs, const Scope &scope) + : FusionDequantBNParam(inputs, outputs, attrs, scope) { + // element wise add params + axis_ = OpParam::GetAttr("axis", attrs); + bias_ = OpParam::InputYFrom(inputs, scope); + // output + if (outputs.count("Y")) { + this->output_ = OpParam::OutputYFrom(outputs, scope); + } + } + + public: + // elementwise add + int axis_; + RType *bias_; +}; +#endif + +#ifdef FUSION_DEQUANT_BN_RELU_OP +template +class FusionDequantBNReluParam : public FusionDequantBNParam { + typedef typename DtypeTensorTrait::gtype GType; + typedef typename DtypeTensorTrait::rtype RType; + + public: + FusionDequantBNReluParam(const VariableNameMap &inputs, + const VariableNameMap &outputs, + const AttributeMap &attrs, const Scope &scope) + : FusionDequantBNParam(inputs, outputs, attrs, scope) { + // output + if (outputs.count("Out")) { + this->output_ = OpParam::OutFrom(outputs, scope); + } + } +}; +#endif + +#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP +template +class FusionDequantAddBNReluParam : public FusionDequantAddBNParam { + typedef typename DtypeTensorTrait::gtype GType; + typedef typename DtypeTensorTrait::rtype RType; + + public: + FusionDequantAddBNReluParam(const VariableNameMap &inputs, + const VariableNameMap &outputs, + const AttributeMap &attrs, const Scope &scope) + : FusionDequantAddBNParam(inputs, outputs, attrs, scope) { + // output + if (outputs.count("Out")) { + this->output_ = OpParam::OutFrom(outputs, scope); + } + } }; #endif diff --git a/test/operators/test_quantize_op.cpp b/test/operators/test_quantize_op.cpp index 9988661bcb898daa5e79b6d22d65d90cfa03c668..50c0e7bd05da7f7a5ee1fd6912be0eff2f6e2958 100644 --- a/test/operators/test_quantize_op.cpp +++ b/test/operators/test_quantize_op.cpp @@ -44,25 +44,19 @@ struct Round { template <> struct Round { int8_t operator()(float x) { - int8_t ret = 0; float v = std::round(x); - int32_t q = (int32_t)v; - if (abs(abs(q - x) - 0.5) > 0) { - ret = q; - } else { - if (abs(q) % 2 == 0) { - ret = q; - } else { - ret = q + ((q > 0) ? -1 : 1); + int32_t q = static_cast(v); + if (abs(abs(q - v) - 0.5) <= 0) { + if (abs(q) % 2 != 0) { + q = q + ((q > 0) ? 
-1 : 1); } } - return ret; + return static_cast(q); } }; template -static void quantize(const Tensor *input, const float scale, const int pad, - const int8_t pad_val, Tensor *output) { +static void quantize(const Tensor *input, const float scale, Tensor *output) { int batch_size = input->dims()[0]; int channels = input->dims()[1]; int input_h = input->dims()[2]; @@ -77,29 +71,9 @@ static void quantize(const Tensor *input, const float scale, const int pad, for (int nc = 0; nc < batch_size * channels; ++nc) { const float *xh = x + nc * input_spatial; int8_t *yh = y + nc * output_spatial; - // pad top - for (int h = 0; h < pad; ++h, yh += output_w) { - for (int w = 0; w < output_w; ++w) { - yh[w] = pad_val; - } - } for (int h = 0; h < input_h; ++h, yh += output_w, xh += input_w) { - // pad left - for (int w = 0; w < pad; ++w) { - yh[w] = pad_val; - } for (int w = 0; w < input_w; ++w) { - yh[w + pad] = Round()(xh[w] * scale); - } - // pad right - for (int w = 0; w < pad; ++w) { - yh[pad + input_w + w] = pad_val; - } - } - // pad bottom - for (int h = 0; h < pad; ++h, yh += output_w) { - for (int w = 0; w < output_w; ++w) { - yh[w] = pad_val; + yh[w] = Round()(xh[w] * scale); } } } @@ -120,19 +94,14 @@ static float find_abs_max(const Tensor *input) { int TestQuqntizeOp(int argc, char *argv[]) { if (argc < 5) { - std::cout - << "Usage: ./test-quantize-op batch_size channel height width [pad]" - << std::endl; + std::cout << "Usage: ./test-quantize-op batch_size channel height width" + << std::endl; return 1; } - int pad = 0; int batch_size = atoi(argv[1]); int channel = atoi(argv[2]); int height = atoi(argv[3]); int width = atoi(argv[4]); - if (argc == 6) { - pad = atoi(argv[5]); - } std::cout << "batch_size: " << batch_size << ", channel: " << channel << ", height: " << height << ", width: " << width << std::endl; framework::DDim dim = @@ -153,7 +122,6 @@ int TestQuqntizeOp(int argc, char *argv[]) { auto output_scale_var = scope.get()->Var("output_scale"); framework::AttributeMap attrs; - attrs["paddings"].Set>(std::vector({pad, pad})); auto *op = new operators::QuantizeOp("quantize", inputs, outputs, attrs, scope); op->InferShape(); @@ -172,9 +140,9 @@ int TestQuqntizeOp(int argc, char *argv[]) { framework::Tensor output_cmp; output_cmp.Resize(output->dims()); float scale = 127 / output_scale_cmp; - // quantize(input, scale, pad, 0, &output_cmp); - // quantize(input, scale, pad, 0, &output_cmp); - quantize(input, scale, pad, 0, &output_cmp); + // quantize(input, scale, &output_cmp); + // quantize(input, scale, &output_cmp); + quantize(input, scale, &output_cmp); int8_t *output_cmp_data = output_cmp.data(); for (int i = 0; i < output->numel(); ++i) { PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i], diff --git a/tools/op.cmake b/tools/op.cmake index 98a5ce437ae6520a4cc27f9fceeadaeb30ba6e99..e2254c3261d53d142e77f09c001d9cbebb5f85ff 100644 --- a/tools/op.cmake +++ b/tools/op.cmake @@ -249,7 +249,9 @@ if(NOT FOUND_MATCH) set(SUM_OP ON) set(QUANT_OP ON) set(DEQUANT_OP ON) - set(FUSION_DEQUANT_ADD_BN_RELU ON) + set(FUSION_DEQUANT_ADD_BN_OP ON) + set(FUSION_DEQUANT_BN_RELU_OP ON) + set(FUSION_DEQUANT_ADD_BN_RELU_OP ON) endif() # option(BATCHNORM_OP "" ON) @@ -451,10 +453,17 @@ endif() if (DEQUANT_OP) add_definitions(-DDEQUANT_OP) endif() -if (FUSION_DEQUANT_ADD_BN_RELU) +if (FUSION_DEQUANT_ADD_BN_OP) + add_definitions(-DFUSION_DEQUANT_ADD_BN_OP) +endif() +if (FUSION_DEQUANT_BN_RELU_OP) + add_definitions(-DFUSION_DEQUANT_BN_RELU_OP) +endif() +if (FUSION_DEQUANT_ADD_BN_RELU_OP) 
    add_definitions(-DFUSION_DEQUANT_ADD_BN_RELU_OP)
 endif()
+
 if (TANH_OP)
     add_definitions(-DTANH_OP)
 endif()
@@ -467,3 +476,4 @@ endif()
 if (FUSION_DECONVADDRELU_OP)
     add_definitions(-DFUSION_DECONVADDRELU_OP)
 endif()
+
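
Note on the new fusion_dequant_add_bn / fusion_dequant_bn_relu ops (illustrative, not part of the patch): conceptually they collapse the int32-to-float dequantization, the optional elementwise bias add, and batch norm into a single per-channel scale and bias. The sketch below shows the folding a kernel's Init() could precompute; the helper and variable names are assumptions, not the exact fields used by the fused kernels in this diff.

// Fold dequantize + (optional elementwise add) + batch norm into one
// per-channel scale/bias pair.
//   x_fp = x_int32 * dequant_scale                       (dequantize)
//   y    = gamma[c] * (x_fp + add_bias[c] - mean[c]) / sqrt(var[c] + eps)
//          + beta[c]                                     (bias add + BN)
#include <cmath>
#include <vector>

struct FoldedBN {
  std::vector<float> scale;  // per-channel multiplier
  std::vector<float> bias;   // per-channel offset
};

FoldedBN FoldDequantAddBN(const std::vector<float> &gamma,
                          const std::vector<float> &beta,
                          const std::vector<float> &mean,
                          const std::vector<float> &var,
                          const std::vector<float> &add_bias, float eps) {
  const int channels = static_cast<int>(gamma.size());
  FoldedBN folded{std::vector<float>(channels), std::vector<float>(channels)};
  for (int c = 0; c < channels; ++c) {
    float inv_std = 1.f / std::sqrt(var[c] + eps);
    folded.scale[c] = gamma[c] * inv_std;
    folded.bias[c] = (add_bias[c] - mean[c]) * folded.scale[c] + beta[c];
  }
  return folded;
}

With the fold done once, the per-element work reduces to y = x_int32 * dequant_scale * scale[c] + bias[c], followed by max(y, 0) for the *_relu variants.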
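
Note on the int8 DepthwiseConv3x3S1/S2 paths (illustrative, not part of the patch): the padded output is split into border regions handled by DepthwiseConv3x3NormalRow (top/bottom rows) and DepthwiseConv3x3ValidCol (left/right columns), plus an interior "valid" region that the 6-wide NEON tiles can process without bounds checks. A minimal sketch of that bookkeeping, mirroring the valid_h_start / valid_w_start computation in the diff (names are assumptions):

// An output row/column is "valid" when its 3x3 window lies entirely inside
// the input: for stride 1 the first padding rows/cols are borders, for
// stride 2 the first (padding + 1) / 2 of them are.
struct ValidRegion {
  int h_start, h_end, w_start, w_end;
};

inline ValidRegion ComputeValidRegion(int output_h, int output_w,
                                      int padding_h, int padding_w,
                                      int stride) {
  int h_border = (stride == 1) ? padding_h : (padding_h + 1) / 2;
  int w_border = (stride == 1) ? padding_w : (padding_w + 1) / 2;
  return {h_border, output_h - h_border, w_border, output_w - w_border};
}

// Rows [0, h_start) and [h_end, output_h) go through DepthwiseConv3x3NormalRow,
// columns [0, w_start) and [w_end, output_w) through DepthwiseConv3x3ValidCol,
// and the remaining interior through the 6-wide NEON tiles.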
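
Note on Round<ROUND_NEAREST_TO_EVEN> in test_quantize_op.cpp (illustrative, not part of the patch): the intent is banker's rounding, i.e. halfway values snap to the nearest even integer. As rewritten, q is cast from the already-rounded v, so abs(q - v) is always zero and the even-adjustment branch cannot fire; a reference version has to detect the tie against the original value x. A standalone sketch under that assumption (clamping to the int8 range is omitted here, as in the test helper):

#include <cmath>
#include <cstdint>

inline int8_t RoundHalfToEven(float x) {
  float rounded = std::round(x);            // std::round sends ties away from zero
  int32_t q = static_cast<int32_t>(rounded);
  if (std::abs(std::abs(x - q) - 0.5f) < 1e-6f) {  // x was exactly halfway
    if (q % 2 != 0) {                       // snap odd results back to even
      q += (q > 0) ? -1 : 1;
    }
  }
  return static_cast<int8_t>(q);
}

// Examples: RoundHalfToEven(2.5f) == 2, RoundHalfToEven(3.5f) == 4,
//           RoundHalfToEven(-2.5f) == -2, RoundHalfToEven(2.4f) == 2.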