diff --git a/src/common/types.cpp b/src/common/types.cpp
index a6f32762d3c8a492c3347ebfe65cb50f39425976..1611a919ffe716e31ff83b8673035a048ebe96d2 100644
--- a/src/common/types.cpp
+++ b/src/common/types.cpp
@@ -23,8 +23,9 @@ const std::string G_OP_TYPE_BOX_CODER = "box_coder";
 const std::string G_OP_TYPE_CONCAT = "concat";
 const std::string G_OP_TYPE_ELEMENTWISE_ADD = "elementwise_add";
 const std::string G_OP_TYPE_FUSION_CONV_ADD_RELU = "fusion_conv_add_relu";
-const std::string G_OP_TYPE_FC = "fc";
-const std::string G_OP_TYPE_CONV_ADD = "conv_add";
+const std::string G_OP_TYPE_FUSION_CONV_ADD_BN_RELU = "fusion_conv_add_bn_relu";
+const std::string G_OP_TYPE_FC = "fusion_fc";
+const std::string G_OP_TYPE_FUSION_CONV_ADD = "fusion_conv_add";
 const std::string G_OP_TYPE_LRN = "lrn";
 const std::string G_OP_TYPE_MUL = "mul";
 const std::string G_OP_TYPE_MULTICLASS_NMS = "multiclass_nms";
@@ -44,7 +45,7 @@ std::unordered_map<
     std::string, std::pair<std::vector<std::string>, std::vector<std::string>>>
     op_input_output_key = {
         {G_OP_TYPE_CONV, {{"Input"}, {"Output"}}},
-        {G_OP_TYPE_CONV_ADD, {{"Input"}, {"Out"}}},
+        {G_OP_TYPE_FUSION_CONV_ADD, {{"Input"}, {"Out"}}},
         {G_OP_TYPE_RELU, {{"X"}, {"Out"}}},
         {G_OP_TYPE_SOFTMAX, {{"X"}, {"Out"}}},
         {G_OP_TYPE_MUL, {{"X"}, {"Out"}}},
@@ -59,6 +60,7 @@
         {G_OP_TYPE_TRANSPOSE, {{"X"}, {"Out"}}},
         {G_OP_TYPE_BOX_CODER,
          {{"PriorBox", "PriorBoxVar", "TargetBox"}, {"OutputBox"}}},
+        {G_OP_TYPE_FUSION_CONV_ADD_BN_RELU, {{"Input"}, {"Out"}}},
         {G_OP_TYPE_PRIOR_BOX, {{"Image", "Input"}, {"Boxes", "Variances"}}},
         {G_OP_TYPE_MULTICLASS_NMS, {{"BBoxes", "Scores"}, {"Out"}}},
         {G_OP_TYPE_FC, {{"X", "Y", "Z"}, {"Out"}}},
diff --git a/src/common/types.h b/src/common/types.h
index 49f0c49a585ac45cbb0a061f72e33f2fb579a82e..9134ebe3561153e32db157c4c4b835a1bc464149 100644
--- a/src/common/types.h
+++ b/src/common/types.h
@@ -79,7 +79,9 @@ extern const std::string G_OP_TYPE_CONCAT;
 extern const std::string G_OP_TYPE_ELEMENTWISE_ADD;
 extern const std::string G_OP_TYPE_FUSION_CONV_ADD_RELU;
 extern const std::string G_OP_TYPE_FC;
-extern const std::string G_OP_TYPE_CONV_ADD;
+extern const std::string G_OP_TYPE_FUSION_CONV_ADD;
+extern const std::string G_OP_TYPE_FUSION_CONV_ADD_BN_RELU;
+
 extern const std::string G_OP_TYPE_LRN;
 extern const std::string G_OP_TYPE_MUL;
 extern const std::string G_OP_TYPE_MULTICLASS_NMS;
diff --git a/src/framework/operator.h b/src/framework/operator.h
index c68744a676030413e81570ded0db5671cdf4ba7a..793551b0cd3eea290243c156c27616a34c37a3d2 100644
--- a/src/framework/operator.h
+++ b/src/framework/operator.h
@@ -63,7 +63,7 @@ class OperatorBase {
   std::vector<std::string> GetOutKeys() const;
   virtual void RunImpl() const = 0;
 
-  virtual void Init() const = 0;
+  virtual void Init() = 0;
   /*
   * @b the inputs the op needs, e.g. the previous layer's output and the convolution kernels
   * */
@@ -117,8 +117,8 @@ class OperatorWithKernel : public OperatorBase<Dtype> {
 
   virtual void InferShape() const = 0;
 
-  void Init() const {
-    PADDLE_MOBILE_ENFORCE(kernel_.Init(param_), " %s kernel init failed",
+  void Init() {
+    PADDLE_MOBILE_ENFORCE(kernel_.Init(&param_), " %s kernel init failed",
                           this->type_.c_str());
   }
 
@@ -146,7 +146,7 @@ class OpKernelBase {
   }
 #endif
   virtual void Compute(const P &para) const = 0;
-  virtual bool Init(const P &para) const { return true; };
+  virtual bool Init(P *para) { return true; };
   virtual ~OpKernelBase() = default;
 
  private:
diff --git a/src/framework/program/program-optimize/node.cpp b/src/framework/program/program-optimize/node.cpp
index 89385e12d9c5f20a21f6ee6f3987c088c4b15563..e635e07eaf4484c3e390101c3b43fdaf24bbd2c6 100644
--- a/src/framework/program/program-optimize/node.cpp
+++ b/src/framework/program/program-optimize/node.cpp
@@ -93,7 +93,8 @@ int Node::Depth(int begin) {
 
 Node &Node::Folder(
     int size, std::string type,
-    std::map<std::string, std::pair<std::string, std::string>> change,
+    std::map<std::string, std::vector<std::pair<std::string, std::string>>>
+        change,
     std::vector<std::shared_ptr<Node>> *removed_nodes) {
   std::shared_ptr<framework::OpDesc> op_desc =
       std::make_shared<framework::OpDesc>();
@@ -110,12 +111,15 @@
 void Node::Folder(
     std::shared_ptr<framework::OpDesc> op_desc,
     std::vector<std::shared_ptr<Node>> *outputs, int index,
-    std::map<std::string, std::pair<std::string, std::string>> *change,
+    std::map<std::string, std::vector<std::pair<std::string, std::string>>>
+        *change,
     Node *begin_node, std::vector<std::shared_ptr<Node>> *removed_nodes) {
   if (change->find(this->type_) != change->end()) {
-    auto change_pair = (*change)[this->type_];
-    op_desc->GetInputs()[change_pair.second] =
-        this->op_desc_->GetInputs()[change_pair.first];
+    auto change_pairs = (*change)[this->type_];
+    for (const auto &change_pair : change_pairs) {
+      op_desc->GetInputs()[change_pair.second] =
+          this->op_desc_->GetInputs()[change_pair.first];
+    }
   }
 
   for (auto &attr_pair : this->op_desc_->attrs_) {
diff --git a/src/framework/program/program-optimize/node.h b/src/framework/program/program-optimize/node.h
index 7236ffdd1782dfb39af73195da9b3756030c9117..88bf1e16ed2a5fb3a038eadd546d63ffb3916f68 100644
--- a/src/framework/program/program-optimize/node.h
+++ b/src/framework/program/program-optimize/node.h
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <map>
 #include <memory>
 #include <string>
+#include <utility>
 #include <vector>
 #include "common/log.h"
 #include "framework/program/op_desc.h"
@@ -43,7 +44,8 @@ class Node {
   int Depth(int begin = 0);
   Node &Folder(
       int size, std::string type,
-      std::map<std::string, std::pair<std::string, std::string>> change_map,
+      std::map<std::string, std::vector<std::pair<std::string, std::string>>>
+          change,
      std::vector<std::shared_ptr<Node>> *removed_nodes);
   std::vector<std::shared_ptr<framework::OpDesc>> OpDescs(int size);
   std::shared_ptr<framework::OpDesc> OpDescOfNode() { return op_desc_; }
@@ -56,7 +58,8 @@ class Node {
   void Folder(
       std::shared_ptr<framework::OpDesc> op_desc,
       std::vector<std::shared_ptr<Node>> *outputs, int index,
-      std::map<std::string, std::pair<std::string, std::string>> *change,
+      std::map<std::string, std::vector<std::pair<std::string, std::string>>>
+          *change,
      Node *begin_node, std::vector<std::shared_ptr<Node>> *removed_nodes);
   std::shared_ptr<framework::OpDesc> op_desc_;
 #ifdef PADDLE_MOBILE_DEBUG
diff --git a/src/operators/feed_op.h b/src/operators/feed_op.h
index bd5fd8cb32d484b7f76652139603f6b0f1b4b5d7..8753bfa9375f50930f9ec57e1b48b26c127edbc6 100644
--- a/src/operators/feed_op.h
+++ b/src/operators/feed_op.h
@@ -32,7 +32,7 @@ class FeedOp : public framework::OperatorBase<DeviceType> {
         param_(inputs, outputs, attrs, *scope) {}
   void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); }
 
-  void Init() const {}
+  void Init() {}
   void InferShape() const {
     auto out_dims = param_.Out()->dims();
diff --git a/src/operators/fetch_op.h b/src/operators/fetch_op.h
index 4b3680b58357d8295b1b6acf111d3573d4e4d1bd..b46093e18e1d92ed9dacbdb456bb591d0c546456 100644
--- a/src/operators/fetch_op.h
+++ b/src/operators/fetch_op.h
@@ -33,7 +33,7 @@ class FetchOp : public framework::OperatorBase<DeviceType> {
         param_(inputs, outputs, attrs, *scope) {}
   void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); }
 
-  void Init() const {}
+  void Init() {}
   void InferShape() const {
     auto x_dims = param_.InputX()->dims();
diff --git a/src/operators/fusion_conv_add.cpp b/src/operators/fusion_conv_add.cpp
index 4c01603509b0a1d9da2c2dc31a38719d5117e05c..731bb66cb058dd8562b5fc9257bd8e9ed5f9c0af 100644
--- a/src/operators/fusion_conv_add.cpp
+++ b/src/operators/fusion_conv_add.cpp
@@ -50,8 +50,8 @@ template class FusionConvAddOp<CPU, float>;
 namespace ops = paddle_mobile::operators;
 #ifdef PADDLE_MOBILE_CPU
-USE_OP_CPU(conv_add);
-REGISTER_OPERATOR_CPU(conv_add, ops::FusionConvAddOp);
+USE_OP_CPU(fusion_conv_add);
+REGISTER_OPERATOR_CPU(fusion_conv_add, ops::FusionConvAddOp);
 #endif
 #ifdef PADDLE_MOBILE_MALI_GPU
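// Note: the MALI GPU registration below still uses the old "conv_add" type
// string, while the CPU registration above and the matcher's Type() now
// return "fusion_conv_add" (G_OP_TYPE_FUSION_CONV_ADD), so fused nodes
// produced by the optimizer will only match the CPU-side name.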
USE_OP_MALI_GPU(conv_add); diff --git a/src/operators/fusion_conv_add.h b/src/operators/fusion_conv_add.h index 73107a3c0adc382dea98663188215ad295c4506b..8b843f55266300b9fbb758b2b5ce43b908d1dc82 100644 --- a/src/operators/fusion_conv_add.h +++ b/src/operators/fusion_conv_add.h @@ -39,10 +39,10 @@ class FusionConvAddMatcher : public framework::FusionOpMatcher { vector> origin_descs = node->OpDescs(node_.Depth()); node->Folder(node_.Depth(), Type(), - {{G_OP_TYPE_ELEMENTWISE_ADD, {"Y", "Y"}}}, removed_nodes); + {{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}}}, removed_nodes); } - std::string Type() { return G_OP_TYPE_CONV_ADD; } + std::string Type() { return G_OP_TYPE_FUSION_CONV_ADD; } }; template @@ -67,11 +67,13 @@ class FusionConvAddOp : public framework::OperatorWithKernel< }; #ifdef PADDLE_MOBILE_CPU + #ifndef CONV_ADD_REGISTER static framework::FusionOpRegistrar convadd_registrar( new FusionConvAddMatcher()); #define CONV_ADD_REGISTER #endif + #endif #ifdef PADDLE_MOBILE_MALI_GPU diff --git a/src/operators/fusion_conv_add_bn_relu_op.cpp b/src/operators/fusion_conv_add_bn_relu_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..63d0b23444ae6bf625e5e8640d3dc2ad314917d2 --- /dev/null +++ b/src/operators/fusion_conv_add_bn_relu_op.cpp @@ -0,0 +1,61 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#ifdef FUSION_CONVADDBNRELU_OP
+
+#include "operators/fusion_conv_add_bn_relu_op.h"
+#include "operators/math/conv_func.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+template <typename Dtype, typename T>
+void FusionConvAddBNReluOp<Dtype, T>::InferShape() const {
+  auto in_dims = this->param_.Input()->dims();
+  auto filter_dims = this->param_.Filter()->dims();
+  const std::vector<int> &strides = this->param_.Strides();
+  std::vector<int> paddings = this->param_.Paddings();
+  int groups = this->param_.Groups();
+  std::vector<int> dilations = this->param_.Dilations();
+
+  PADDLE_MOBILE_ENFORCE((in_dims.size() == filter_dims.size() &&
+                         dilations.size() == paddings.size() &&
+                         paddings.size() == strides.size()),
+                        "ConvParam is not suitable");
+
+  std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
+  for (size_t i = 0; i < strides.size(); ++i) {
+    output_shape.push_back(
+        math::ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i],
+                             paddings[i], strides[i]));
+  }
+
+  framework::DDim ddim = framework::make_ddim(output_shape);
+  this->param_.Output()->Resize(ddim);
+}
+template class FusionConvAddBNReluOp<CPU, float>;
+}  // namespace operators
+}  // namespace paddle_mobile
+
+namespace ops = paddle_mobile::operators;
+#ifdef PADDLE_MOBILE_CPU
+USE_OP_CPU(fusion_conv_add_bn_relu);
+REGISTER_OPERATOR_CPU(fusion_conv_add_bn_relu, ops::FusionConvAddBNReluOp);
+#endif
+#ifdef PADDLE_MOBILE_MALI_GPU
+#endif
+#ifdef PADDLE_MOBILE_FPGA
+#endif
+
+#endif
diff --git a/src/operators/fusion_conv_add_bn_relu_op.h b/src/operators/fusion_conv_add_bn_relu_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..494e49280dbdc3fe778cd7bdf5f5d30a82f2d9ff
--- /dev/null
+++ b/src/operators/fusion_conv_add_bn_relu_op.h
@@ -0,0 +1,106 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/ + +#ifdef FUSION_CONVADDBNRELU_OP + +#pragma once + +#include +#include +#include "framework/operator.h" +#include "framework/program/program-optimize/fusion_op_register.h" +#include "op_param.h" +#include "operators/kernel/conv_add_bn_relu_kernel.h" + +namespace paddle_mobile { +namespace operators { +using std::string; +using std::vector; +class FusionConvAddBNReluMatcher : public framework::FusionOpMatcher { + public: + FusionConvAddBNReluMatcher() { + node_ = framework::Node(G_OP_TYPE_CONV); + node_ > std::make_shared(G_OP_TYPE_ELEMENTWISE_ADD) > + std::make_shared(G_OP_TYPE_BATCHNORM) > + std::make_shared(G_OP_TYPE_RELU); + } + + void FolderNodes( + framework::Node *node, + std::vector> *removed_nodes) { + vector> origin_descs = + node->OpDescs(node_.Depth()); + node->Folder(node_.Depth(), Type(), + {{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}}, + {G_OP_TYPE_BATCHNORM, + {{"Scale", "Scale"}, + {"Mean", "Mean"}, + {"Bias", "Bias"}, + {"Variance", "Variance"}}}}, + removed_nodes); + } + + std::string Type() { return G_OP_TYPE_FUSION_CONV_ADD_BN_RELU; } +}; + +template +class FusionConvAddBNReluOp + : public framework::OperatorWithKernel< + DeviceType, FusionConvAddBNReluParam, + operators::ConvAddBNReluKernel> { + public: + FusionConvAddBNReluOp(const string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, + const framework::AttributeMap &attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel< + DeviceType, FusionConvAddBNReluParam, + operators::ConvAddBNReluKernel>( + type, inputs, outputs, attrs, scope) {} + + using framework::OperatorWithKernel< + DeviceType, FusionConvAddBNReluParam, + operators::ConvAddBNReluKernel>::OperatorWithKernel; + void InferShape() const override; + + protected: +}; + +#ifdef PADDLE_MOBILE_CPU + +//#ifndef FUSION_CONV_ADD_BN_RELU_REGISTER +// static framework::FusionOpRegistrar fusion_conv_add_bn_relu_registrar( +// new FusionConvAddBNReluMatcher()); +//#define FUSION_CONV_ADD_BN_RELU_REGISTER +//#endif + +#endif + +#ifdef PADDLE_MOBILE_MALI_GPU + +#ifndef FUSION_CONV_ADD_BN_RELU_REGISTER +static framework::FusionOpRegistrar fusion_conv_add_bn_relu_registrar( + new FusionConvAddBNReluMatcher()); +#define FUSION_CONV_ADD_BN_RELU_REGISTER +#endif + +#endif + +#ifdef PADDLE_MOBILE_FPGA +#endif + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/fusion_conv_add_relu_op.h b/src/operators/fusion_conv_add_relu_op.h index fd27005c8bef8f8cb91fbf5b6e5a852306c28a9b..bcacb3da3e2ec5371021f3552ffd2c9f53947874 100644 --- a/src/operators/fusion_conv_add_relu_op.h +++ b/src/operators/fusion_conv_add_relu_op.h @@ -36,7 +36,7 @@ class FusionConvAddReluOpMatcher : public framework::FusionOpMatcher { framework::Node *node, std::vector> *removed_nodes) { node->Folder(node_.Depth(), Type(), - {{G_OP_TYPE_ELEMENTWISE_ADD, {"Y", "Y"}}}, removed_nodes); + {{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}}}, removed_nodes); } std::string Type() { return G_OP_TYPE_FUSION_CONV_ADD_RELU; } }; @@ -65,11 +65,11 @@ class FusionConvAddReluOp : public framework::OperatorWithKernel< #ifdef PADDLE_MOBILE_CPU -#ifndef CONV_ADD_RELU_REGISTER -#define CONV_ADD_RELU_REGISTER +//#ifndef CONV_ADD_RELU_REGISTER +//#define CONV_ADD_RELU_REGISTER // static framework::FusionOpRegistrar fusion_conv_add_relu_registrar(new // FusionConvAddReluOpMatcher()); -#endif +//#endif #endif #ifdef PADDLE_MOBILE_MALI_GPU diff --git a/src/operators/fusion_fc_op.cpp b/src/operators/fusion_fc_op.cpp index 
fae561348899dadc4c25f84ec3a0993d9ae693f9..2e591b678cf7987eba5fdc74643cd7e15c35271f 100644 --- a/src/operators/fusion_fc_op.cpp +++ b/src/operators/fusion_fc_op.cpp @@ -55,8 +55,8 @@ template class FusionFcOp; namespace ops = paddle_mobile::operators; #ifdef PADDLE_MOBILE_CPU -USE_OP_CPU(fc); -REGISTER_OPERATOR_CPU(fc, ops::FusionFcOp); +USE_OP_CPU(fusion_fc); +REGISTER_OPERATOR_CPU(fusion_fc, ops::FusionFcOp); #endif #ifdef PADDLE_MOBILE_MALI_GPU USE_OP_MALI_GPU(fc); diff --git a/src/operators/fusion_fc_op.h b/src/operators/fusion_fc_op.h index 0ca4d2b27ad46b77ddba55b6b377e741c97bdc9e..ea1f42f0adfb532982f50c2da41fc58f63b54834 100644 --- a/src/operators/fusion_fc_op.h +++ b/src/operators/fusion_fc_op.h @@ -38,7 +38,7 @@ class FusionFcMatcher : public framework::FusionOpMatcher { framework::Node *node, std::vector> *removed_nodes) { node->Folder(node_.Depth(), Type(), - {{G_OP_TYPE_ELEMENTWISE_ADD, {"Y", "Z"}}}, removed_nodes); + {{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Z"}}}}, removed_nodes); } std::string Type() { return G_OP_TYPE_FC; } @@ -66,17 +66,21 @@ class FusionFcOp }; #ifdef PADDLE_MOBILE_CPU + #ifndef CONV_CPU_REGISTER #define CONV_CPU_REGISTER static framework::FusionOpRegistrar fc_registrar(new FusionFcMatcher()); #endif + #endif #ifdef PADDLE_MOBILE_MALI_GPU + #ifndef CONV_CPU_REGISTER #define CONV_CPU_REGISTER static framework::FusionOpRegistrar fc_registrar(new FusionFcMatcher()); #endif + #endif #ifdef PADDLE_MOBILE_FPGA diff --git a/src/operators/kernel/arm/batchnorm_kernel.cpp b/src/operators/kernel/arm/batchnorm_kernel.cpp index 964bf71f451e2ca48d3742ed5151e9784c516d5c..f78d1fdc95ac9e10619dbf32fdc84d01a370f315 100644 --- a/src/operators/kernel/arm/batchnorm_kernel.cpp +++ b/src/operators/kernel/arm/batchnorm_kernel.cpp @@ -21,7 +21,7 @@ namespace paddle_mobile { namespace operators { template <> -bool BatchNormKernel::Init(const BatchNormParam ¶) const { +bool BatchNormKernel::Init(BatchNormParam *param) { return true; } diff --git a/src/operators/kernel/arm/box_coder_kernel.cpp b/src/operators/kernel/arm/box_coder_kernel.cpp index df0a75f357658736ede4265a6cc57db30afee1d4..fb113b16f53bcd1b9fca7a1dbbf94a846e9a0f81 100644 --- a/src/operators/kernel/arm/box_coder_kernel.cpp +++ b/src/operators/kernel/arm/box_coder_kernel.cpp @@ -111,7 +111,7 @@ void DecodeCenterSize(const framework::Tensor& target_box, } template <> -bool BoxCoderKernel::Init(const BoxCoderParam& para) const { +bool BoxCoderKernel::Init(BoxCoderParam* param) { return true; } diff --git a/src/operators/kernel/arm/concat_kernel.cpp b/src/operators/kernel/arm/concat_kernel.cpp index 0312047b8e8af1eb9dad57c751e392e8a5054878..5daf3e104a04025165ce7281f3a16d8e3f9cb522 100644 --- a/src/operators/kernel/arm/concat_kernel.cpp +++ b/src/operators/kernel/arm/concat_kernel.cpp @@ -53,7 +53,7 @@ class ConcatFunctor { }; template <> -bool ConcatKernel::Init(const ConcatParam ¶) const { +bool ConcatKernel::Init(ConcatParam *param) { return true; } diff --git a/src/operators/kernel/arm/conv_add_bn_relu_kernel.cpp b/src/operators/kernel/arm/conv_add_bn_relu_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e95bd8e76c5034f3897eff81e0ba67119d04a95b --- /dev/null +++ b/src/operators/kernel/arm/conv_add_bn_relu_kernel.cpp @@ -0,0 +1,65 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef FUSION_CONVADDBNRELU_OP
+
+#include "operators/kernel/conv_add_bn_relu_kernel.h"
+#include "operators/kernel/central-arm-func/conv_add_bn_relu_func.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+template <>
+bool ConvAddBNReluKernel<CPU, float>::Init(FusionConvAddBNReluParam *param) {
+  const Tensor *mean = (*param).InputMean();
+  const Tensor *variance = (*param).InputVariance();
+  const Tensor *scale = (*param).InputScale();
+  const Tensor *bias = (*param).InputBias();
+  const float epsilon = (*param).Epsilon();
+
+  auto mean_ptr = mean->data<float>();
+  auto variance_ptr = variance->data<float>();
+  auto scale_ptr = scale->data<float>();
+  auto bias_ptr = bias->data<float>();
+
+  const int C = mean->numel();
+  float inv_std_ptr[C];
+  for (int i = 0; i < C; i++) {
+    inv_std_ptr[i] =
+        1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
+  }
+  Tensor *new_scale = new Tensor();
+  Tensor *new_bias = new Tensor();
+  auto new_scale_ptr = new_scale->mutable_data<float>({C});
+  auto new_bias_ptr = new_bias->mutable_data<float>({C});
+  for (int i = 0; i < C; i++) {
+    new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
+    new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
+  }
+  (*param).SetNewScale(new_scale);
+  (*param).SetNewBias(new_bias);
+  return true;
+}
+
+template <>
+void ConvAddBNReluKernel<CPU, float>::Compute(
+    const FusionConvAddBNReluParam &param) const {
+  ConvAddBNReluCompute<float>(param);
+}
+template class ConvAddBNReluKernel<CPU, float>;
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#endif
diff --git a/src/operators/kernel/arm/conv_add_kernel.cpp b/src/operators/kernel/arm/conv_add_kernel.cpp
index 2c7aef932dc68e7a29bf60760751be0f9598cd42..64d6dfa64dc3feae5b73a17ae5b148053df34a0b 100644
--- a/src/operators/kernel/arm/conv_add_kernel.cpp
+++ b/src/operators/kernel/arm/conv_add_kernel.cpp
@@ -19,7 +19,7 @@ namespace paddle_mobile {
 namespace operators {
 
 template <>
-bool ConvAddKernel<CPU, float>::Init(const FusionConvAddParam &para) const {
+bool ConvAddKernel<CPU, float>::Init(FusionConvAddParam *param) {
   return true;
 }
 
diff --git a/src/operators/kernel/arm/conv_add_relu_kernel.cpp b/src/operators/kernel/arm/conv_add_relu_kernel.cpp
index d3c04179b37014adc6c81f32dd6c08f697283671..356dd191e761afc5d5b6bfacd250f90ae31017b2 100644
--- a/src/operators/kernel/arm/conv_add_relu_kernel.cpp
+++ b/src/operators/kernel/arm/conv_add_relu_kernel.cpp
@@ -21,8 +21,7 @@ namespace paddle_mobile {
 namespace operators {
 
 template <>
-bool ConvAddReluKernel<CPU, float>::Init(
-    const FusionConvAddReluParam &para) const {
+bool ConvAddReluKernel<CPU, float>::Init(FusionConvAddReluParam *param) {
   return true;
 }
 
diff --git a/src/operators/kernel/arm/conv_kernel.cpp b/src/operators/kernel/arm/conv_kernel.cpp
index 049425d88f96a322a0b4cb47c18d85f2df03d577..ca8aeff0dd3db5fe7b625bdeb947b2927eb619ce 100644
--- a/src/operators/kernel/arm/conv_kernel.cpp
+++ b/src/operators/kernel/arm/conv_kernel.cpp
@@ -21,7 +21,7 @@ namespace paddle_mobile {
 namespace operators {
 
 template <>
-bool ConvKernel<CPU, float>::Init(const ConvParam &para) const {
+bool ConvKernel<CPU, float>::Init(ConvParam *param) {
   return true;
 }
 
diff --git a/src/operators/kernel/arm/depthwise_conv_kernel.cpp
b/src/operators/kernel/arm/depthwise_conv_kernel.cpp index 4cbfa23248e87e2bf3a8d97330fa19f92985a9d0..6ede0e2bef2383df8aa0593a07297f2f6233acaf 100644 --- a/src/operators/kernel/arm/depthwise_conv_kernel.cpp +++ b/src/operators/kernel/arm/depthwise_conv_kernel.cpp @@ -21,7 +21,7 @@ namespace paddle_mobile { namespace operators { template <> -bool DepthwiseConvKernel::Init(const ConvParam ¶) const { +bool DepthwiseConvKernel::Init(ConvParam *param) { return true; } diff --git a/src/operators/kernel/arm/elementwise_add_kernel.cpp b/src/operators/kernel/arm/elementwise_add_kernel.cpp index 2f5e26a37e4f2c1d370805ee7b565a60f4748b0a..bd9bb26d299bd340074965e41e5658df86bab347 100644 --- a/src/operators/kernel/arm/elementwise_add_kernel.cpp +++ b/src/operators/kernel/arm/elementwise_add_kernel.cpp @@ -27,8 +27,7 @@ struct AddFunctor { }; template <> -bool ElementwiseAddKernel::Init( - const ElementwiseAddParam ¶) const { +bool ElementwiseAddKernel::Init(ElementwiseAddParam *param) { return true; } diff --git a/src/operators/kernel/arm/fusion_fc_kernel.cpp b/src/operators/kernel/arm/fusion_fc_kernel.cpp index 5fac70e40781593669abd15b8f28ff6272f7133c..e10f11c0b19edf710ffc45f199f096bea0a34b7d 100644 --- a/src/operators/kernel/arm/fusion_fc_kernel.cpp +++ b/src/operators/kernel/arm/fusion_fc_kernel.cpp @@ -22,7 +22,7 @@ namespace paddle_mobile { namespace operators { template <> -bool FusionFcKernel::Init(const FusionFcParam ¶) const { +bool FusionFcKernel::Init(FusionFcParam *param) { return true; } diff --git a/src/operators/kernel/arm/lrn_kernel.cpp b/src/operators/kernel/arm/lrn_kernel.cpp index 839c5ee95bd4d1e9d3fd80af3df0f8a45797434e..356aa388276d9d0359b1a6b3a45c86bcb822fd9e 100644 --- a/src/operators/kernel/arm/lrn_kernel.cpp +++ b/src/operators/kernel/arm/lrn_kernel.cpp @@ -22,7 +22,7 @@ namespace paddle_mobile { namespace operators { template <> -bool LrnKernel::Init(const LrnParam ¶) const { +bool LrnKernel::Init(LrnParam *param) { return true; } diff --git a/src/operators/kernel/arm/mul_kernel.cpp b/src/operators/kernel/arm/mul_kernel.cpp index b3bb2b8075fdf306d47640c2bee3f2fc00ef0bc0..99b6576d364671be78efa9a8f2ebf85a6e133f33 100644 --- a/src/operators/kernel/arm/mul_kernel.cpp +++ b/src/operators/kernel/arm/mul_kernel.cpp @@ -22,7 +22,7 @@ namespace paddle_mobile { namespace operators { template <> -bool MulKernel::Init(const MulParam ¶) const { +bool MulKernel::Init(MulParam *param) { return true; } diff --git a/src/operators/kernel/arm/multiclass_nms_kernel.cpp b/src/operators/kernel/arm/multiclass_nms_kernel.cpp index 67cf8197ca4c3113fc4fde3d493d6ed209221b59..ecdc60f77b0cad3af2e8b026ab3666394dc43fee 100644 --- a/src/operators/kernel/arm/multiclass_nms_kernel.cpp +++ b/src/operators/kernel/arm/multiclass_nms_kernel.cpp @@ -204,8 +204,7 @@ void MultiClassOutput(const Tensor& scores, const Tensor& bboxes, } template <> -bool MultiClassNMSKernel::Init( - const MultiClassNMSParam& para) const { +bool MultiClassNMSKernel::Init(MultiClassNMSParam* param) { return true; } diff --git a/src/operators/kernel/arm/pool_kernel.cpp b/src/operators/kernel/arm/pool_kernel.cpp index 09162a13a4d0c59220cc25a02d06369c3f21ed32..5c92d5be014faf4007c0853bde08e450ebc4f79a 100644 --- a/src/operators/kernel/arm/pool_kernel.cpp +++ b/src/operators/kernel/arm/pool_kernel.cpp @@ -36,7 +36,7 @@ inline void PoolBasic(std::string pooling_type, std::vector ksize, } template <> -bool PoolKernel::Init(const PoolParam ¶) const { +bool PoolKernel::Init(PoolParam *param) { return true; } diff --git 
a/src/operators/kernel/arm/prior_box_kernel.cpp b/src/operators/kernel/arm/prior_box_kernel.cpp index 13939bc7bf27904405677560f17d2e0b85748310..32d3818ef244e4c2879167b4273b0538eef08c56 100644 --- a/src/operators/kernel/arm/prior_box_kernel.cpp +++ b/src/operators/kernel/arm/prior_box_kernel.cpp @@ -27,7 +27,7 @@ struct ClipFunctor { }; template <> -bool PriorBoxKernel::Init(const PriorBoxParam ¶) const { +bool PriorBoxKernel::Init(PriorBoxParam *param) { return true; } diff --git a/src/operators/kernel/arm/relu_kernel.cpp b/src/operators/kernel/arm/relu_kernel.cpp index 5bc485b77a8fac9379adbd1a3bd4d406e5a82fcb..f6480dea75289cf6615a9737acfd913a3cb13008 100644 --- a/src/operators/kernel/arm/relu_kernel.cpp +++ b/src/operators/kernel/arm/relu_kernel.cpp @@ -26,7 +26,7 @@ struct ReluFunctor { }; template <> -bool ReluKernel::Init(const ReluParam ¶) const { +bool ReluKernel::Init(ReluParam *param) { return true; } diff --git a/src/operators/kernel/arm/reshape_kernel.cpp b/src/operators/kernel/arm/reshape_kernel.cpp index 97364f9a3f7ce9fe8da5814ad2a483f858938bbf..9e0fd96d3ecd9772ef6e95bc12bb071a25a1d84a 100644 --- a/src/operators/kernel/arm/reshape_kernel.cpp +++ b/src/operators/kernel/arm/reshape_kernel.cpp @@ -20,7 +20,7 @@ namespace paddle_mobile { namespace operators { template <> -bool ReshapeKernel::Init(const ReshapeParam ¶) const { +bool ReshapeKernel::Init(ReshapeParam *param) { return true; } diff --git a/src/operators/kernel/arm/sigmoid_kernel.cpp b/src/operators/kernel/arm/sigmoid_kernel.cpp index 3e87bfacc5335e52ecdcb0b917f5826b80449ef4..5eb65cd6cebf453e46dc16c4982f81cb679bbc72 100644 --- a/src/operators/kernel/arm/sigmoid_kernel.cpp +++ b/src/operators/kernel/arm/sigmoid_kernel.cpp @@ -72,7 +72,7 @@ void sigmoid(const Tensor *X, Tensor *Y) { } template <> -bool SigmoidKernel::Init(const SigmoidParam ¶) const { +bool SigmoidKernel::Init(SigmoidParam *param) { return true; } diff --git a/src/operators/kernel/arm/softmax_kernel.cpp b/src/operators/kernel/arm/softmax_kernel.cpp index 8e966aa0af9ac84b70b154b33bad7dad9e79121d..29006d48dc00b650a725cd0a9cc3c37568e829a9 100644 --- a/src/operators/kernel/arm/softmax_kernel.cpp +++ b/src/operators/kernel/arm/softmax_kernel.cpp @@ -20,7 +20,7 @@ namespace paddle_mobile { namespace operators { template <> -bool SoftmaxKernel::Init(const SoftmaxParam ¶) const { +bool SoftmaxKernel::Init(SoftmaxParam *param) { return true; } diff --git a/src/operators/kernel/arm/transpose_kernel.cpp b/src/operators/kernel/arm/transpose_kernel.cpp index a44ff22a2f228cc357c066a01e142de7cc4f2083..f697d4ca473d64b834fe1451afd8e0df7f84b3a6 100644 --- a/src/operators/kernel/arm/transpose_kernel.cpp +++ b/src/operators/kernel/arm/transpose_kernel.cpp @@ -35,7 +35,7 @@ namespace operators { // } template <> -bool TransposeKernel::Init(const TransposeParam& para) const { +bool TransposeKernel::Init(TransposeParam* param) { return true; } diff --git a/src/operators/kernel/batchnorm_kernel.h b/src/operators/kernel/batchnorm_kernel.h index 6ef5329bc58fea8bfc17d9115b7004fed2bc4ed7..367dd0996c0df5fba7c3570285cf5e2cfd3fac99 100644 --- a/src/operators/kernel/batchnorm_kernel.h +++ b/src/operators/kernel/batchnorm_kernel.h @@ -29,7 +29,7 @@ class BatchNormKernel : public framework::OpKernelBase { public: void Compute(const BatchNormParam ¶m) const; - bool Init(const BatchNormParam ¶) const; + bool Init(BatchNormParam *param); }; } // namespace operators diff --git a/src/operators/kernel/box_coder_kernel.h b/src/operators/kernel/box_coder_kernel.h index 
4c4206f52b3ffc5e60983bf1d6adb25292d01ac4..2ad63ecd90a07d955c3e239277ac1bd60f3510bb 100644 --- a/src/operators/kernel/box_coder_kernel.h +++ b/src/operators/kernel/box_coder_kernel.h @@ -30,7 +30,7 @@ class BoxCoderKernel : public framework::OpKernelBase { public: void Compute(const BoxCoderParam& param) const; - bool Init(const BoxCoderParam& para) const; + bool Init(BoxCoderParam* param); }; } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/central-arm-func/conv_add_bn_relu_func.h b/src/operators/kernel/central-arm-func/conv_add_bn_relu_func.h new file mode 100644 index 0000000000000000000000000000000000000000..6fce2c26347183c47ae756b07156de76f37ea6e5 --- /dev/null +++ b/src/operators/kernel/central-arm-func/conv_add_bn_relu_func.h @@ -0,0 +1,136 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef FUSION_CONVADDBNRELU_OP + +#pragma once +#include "operators/kernel/conv_add_bn_relu_kernel.h" +#include "operators/math/depthwise_conv_3x3.h" +#include "operators/op_param.h" +namespace paddle_mobile { +namespace operators { + +template +void ConvAddBNReluCompute(const FusionConvAddBNReluParam ¶m) { + const Tensor *input = param.Input(); + Tensor filter = *param.Filter(); + Tensor bias = *param.Bias(); + Tensor new_bias = *param.NewBias(); + Tensor new_scale = *param.NewScale(); + auto new_bias_ptr = new_bias.data(); + auto new_scale_ptr = new_scale.data(); + int axis = param.Axis(); + int groups = param.Groups(); + std::vector strides = param.Strides(); + std::vector paddings = param.Paddings(); + std::vector dilations = param.Dilations(); + Tensor *output = param.Output(); + std::vector filter_shape_vec(framework::vectorize(filter.dims())); + + if (filter_shape_vec[2] == 3 && strides[0] == 1 && groups > 1) { + math::DepthwiseConvAddBNRelu3x3s1p1(input, filter, output, &bias, 1, + &new_scale, &new_bias, 1, 1); + } else { + const int batch_size = static_cast(input->dims()[0]); + + math::expand_bias(bias, axis, output->dims()); + output->ShareDataWith(bias); + + std::vector output_shape_vec(framework::vectorize(output->dims())); + size_t data_dim = filter_shape_vec.size() - 2; + std::vector col_shape_vec(1 + 2 * data_dim); + col_shape_vec[0] = input->dims()[1] / groups; + for (size_t j = 0; j < data_dim; ++j) { + col_shape_vec[j + 1] = filter_shape_vec[j + 2]; + col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2]; + } + framework::DDim col_shape(framework::make_ddim(col_shape_vec)); + + framework::DDim col_matrix_shape = + framework::flatten_to_2d(col_shape, data_dim + 1); + + bool is_expand = + math::IsExpand(filter_shape_vec, strides, paddings, dilations); + Tensor col; + Tensor col_matrix; + if (is_expand) { + col.mutable_data(col_shape); + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + } + + framework::DDim input_shape = framework::slice_ddim( + input->dims(), 1, static_cast(input->dims().size())); + + framework::DDim filter_matrix_shape = {filter.dims()[0], + 
filter.numel() / filter.dims()[0]}; + filter.Resize(filter_matrix_shape); + framework::DDim output_matrix_shape = { + output->dims()[1], + output->numel() / (output->dims()[0] * output->dims()[1])}; + + // convolution operator: im2col(or vol2col) + gemm + int in_step = static_cast(input->dims()[1]) / groups; + int out_step = static_cast(output->dims()[1]) / groups; + + math::Vol2ColFunctor vol2col; + math::Im2ColFunctor im2col; + + for (int i = 0; i < batch_size; i++) { + Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); + + for (int g = 0; g < groups; g++) { + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + + if (!is_expand) { + col.ShareDataWith(in_slice); + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + } else if (data_dim == 2U) { + // im2col + im2col(in_slice, dilations, strides, + std::vector{paddings[0], paddings[1], paddings[0], + paddings[1]}, + &col); + } else if (data_dim == 3U) { + // vol2col + vol2col(in_slice, dilations, strides, paddings, &col); + } + + // gemm + Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); + math::matmul(filter_slice, false, col_matrix, false, + static_cast(1), &out_slice, + static_cast(1), false); + } + } + + auto output_ptr = output->data(); + for (int c = 0; c < output_matrix_shape[0]; c++) { + int start = c * output_matrix_shape[1]; + for (int j = 0; j < output_matrix_shape[1]; j++) { + output_ptr[start + j] = + output_ptr[start + j] * new_scale_ptr[c] + new_bias_ptr[c]; + output_ptr[start + j] = + output_ptr[start + j] < 0 ? 0 : output_ptr[start + j]; + } + } + } +} +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/concat_kernel.h b/src/operators/kernel/concat_kernel.h index 6a7b7c6005b6e85e5b1ccfee713672b6e333b98a..adba64391e3e79569030c95e2d2681a31187f03a 100644 --- a/src/operators/kernel/concat_kernel.h +++ b/src/operators/kernel/concat_kernel.h @@ -27,7 +27,7 @@ template class ConcatKernel : public framework::OpKernelBase { public: void Compute(const ConcatParam ¶m) const; - bool Init(const ConcatParam ¶) const; + bool Init(ConcatParam *param); }; } // namespace operators diff --git a/src/operators/math/depthwiseconv3x3s1p1.h b/src/operators/kernel/conv_add_bn_relu_kernel.h similarity index 53% rename from src/operators/math/depthwiseconv3x3s1p1.h rename to src/operators/kernel/conv_add_bn_relu_kernel.h index 019237a43192f30dfb70fe85e6b16a835cba4eba..73aaf4c900393b9cbee4682fc67147d9ef0853fc 100644 --- a/src/operators/math/depthwiseconv3x3s1p1.h +++ b/src/operators/kernel/conv_add_bn_relu_kernel.h @@ -13,15 +13,33 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #pragma once -#include "framework/tensor.h" + +#ifdef FUSION_CONVADDBNRELU_OP + +#include +#include "framework/ddim.h" +#include "framework/operator.h" +#include "operators/math/conv_func.h" +#include "operators/math/im2col.h" +#include "operators/math/math_function.h" +#include "operators/math/vol2col.h" +#include "operators/op_param.h" namespace paddle_mobile { namespace operators { -namespace math { -using framework::Tensor; -void DepthwiseConv3x3s1p1(const Tensor *input, Tensor filter, Tensor *output, - Tensor bias, bool if_bias); -} // namespace math +using framework::DDim; +using framework::OpKernelBase; + +template +class ConvAddBNReluKernel + : public OpKernelBase { + public: + void Compute(const FusionConvAddBNReluParam ¶m) const; + bool Init(FusionConvAddBNReluParam *param); +}; + } // namespace operators } // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/conv_add_kernel.h b/src/operators/kernel/conv_add_kernel.h index fb161238fee0550a42cd62cc132d6e8dbf45872f..465d8bdd8cfd71d678eb2816cae10ea6a06cec35 100644 --- a/src/operators/kernel/conv_add_kernel.h +++ b/src/operators/kernel/conv_add_kernel.h @@ -40,7 +40,7 @@ template class ConvAddKernel : public OpKernelBase { public: void Compute(const FusionConvAddParam ¶m) const; - bool Init(const FusionConvAddParam ¶) const; + bool Init(FusionConvAddParam *param); }; } // namespace operators diff --git a/src/operators/kernel/conv_add_relu_kernel.h b/src/operators/kernel/conv_add_relu_kernel.h index 9b86cd22e82e641ee6cb0a15bd25c8a1c6cbe8cb..3f36d80c4781aebea756b04e340d056a79cfd7d7 100644 --- a/src/operators/kernel/conv_add_relu_kernel.h +++ b/src/operators/kernel/conv_add_relu_kernel.h @@ -36,7 +36,7 @@ class ConvAddReluKernel : public OpKernelBase { public: void Compute(const FusionConvAddReluParam ¶m) const; - bool Init(const FusionConvAddReluParam ¶) const; + bool Init(FusionConvAddReluParam *param); }; } // namespace operators diff --git a/src/operators/kernel/conv_kernel.h b/src/operators/kernel/conv_kernel.h index 812ddd5a441f3a24c557546c1780248a557a6eb0..fedbee32a006f263fd3de25064496dad1a23177b 100644 --- a/src/operators/kernel/conv_kernel.h +++ b/src/operators/kernel/conv_kernel.h @@ -32,7 +32,7 @@ template class ConvKernel : public OpKernelBase { public: void Compute(const ConvParam ¶m) const; - bool Init(const ConvParam ¶) const; + bool Init(ConvParam *param); }; } // namespace operators diff --git a/src/operators/kernel/depthwise_conv_kernel.h b/src/operators/kernel/depthwise_conv_kernel.h index a8a8fb338620477670477703018bf9e6e9a8a604..b74a58a649bd9fa27e941e2cd5ea50b30c0218cb 100644 --- a/src/operators/kernel/depthwise_conv_kernel.h +++ b/src/operators/kernel/depthwise_conv_kernel.h @@ -31,7 +31,7 @@ template class DepthwiseConvKernel : public OpKernelBase { public: void Compute(const ConvParam ¶m) const; - bool Init(const ConvParam ¶) const; + bool Init(ConvParam *param); }; } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/elementwise_add_kernel.h b/src/operators/kernel/elementwise_add_kernel.h index fe6a0238dcd5249e822de3b5930438df808bf853..70334c1d3f788f60e974da74133823f82ab05765 100644 --- a/src/operators/kernel/elementwise_add_kernel.h +++ b/src/operators/kernel/elementwise_add_kernel.h @@ -30,7 +30,7 @@ class ElementwiseAddKernel : public framework::OpKernelBase { public: void Compute(const ElementwiseAddParam ¶m) const; - bool Init(const ElementwiseAddParam ¶) const; + bool Init(ElementwiseAddParam *param); }; } // namespace operators } // namespace 
paddle_mobile diff --git a/src/operators/kernel/fpga/conv_kernel.cpp b/src/operators/kernel/fpga/conv_kernel.cpp index 30dd64fd1466902036a72faa4be5d359d2bdb0bf..dc537362a216983974bea325433c456136356fc8 100644 --- a/src/operators/kernel/fpga/conv_kernel.cpp +++ b/src/operators/kernel/fpga/conv_kernel.cpp @@ -20,7 +20,7 @@ namespace paddle_mobile { namespace operators { template <> -bool ConvKernel::Init(const ConvParam ¶) const { +bool ConvKernel::Init(ConvParam *param) { return true; } diff --git a/src/operators/kernel/fusion_fc_kernel.h b/src/operators/kernel/fusion_fc_kernel.h index c4e2b30176fb904d7fb906c5efc5137a5dcb8d59..0e31134ba5a18405a5855db1e85b3885608c4071 100644 --- a/src/operators/kernel/fusion_fc_kernel.h +++ b/src/operators/kernel/fusion_fc_kernel.h @@ -28,7 +28,7 @@ class FusionFcKernel : public framework::OpKernelBase { public: void Compute(const FusionFcParam& param) const; - bool Init(const FusionFcParam& para) const; + bool Init(FusionFcParam* param); }; } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/lrn_kernel.h b/src/operators/kernel/lrn_kernel.h index 40c48b3663c6825e03028439725c428ce048d254..7327451a0aa21b7bcf9ae111f63c19f2b6bb2d3a 100644 --- a/src/operators/kernel/lrn_kernel.h +++ b/src/operators/kernel/lrn_kernel.h @@ -170,7 +170,7 @@ template class LrnKernel : public framework::OpKernelBase { public: void Compute(const LrnParam ¶m) const; - bool Init(const LrnParam ¶) const; + bool Init(LrnParam *param); }; } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/mali/batchnorm_kernel.cpp b/src/operators/kernel/mali/batchnorm_kernel.cpp index ff27afc71c42ed1c2b7e67eefbdadd86e92cc0fc..22ce472c464bc9ed89ee721244e9873c01601ebd 100644 --- a/src/operators/kernel/mali/batchnorm_kernel.cpp +++ b/src/operators/kernel/mali/batchnorm_kernel.cpp @@ -128,7 +128,7 @@ class AclBatchNormOp : public acl::ACLOperator { }; template <> -bool BatchNormKernel::Init(const BatchNormParam& param) const { +bool BatchNormKernel::Init(BatchNormParam* param) { AclBatchNormOp* acl_op = reinterpret_cast*>(this->GetAclOp()); if (acl_op == nullptr) { diff --git a/src/operators/kernel/mali/conv_kernel.cpp b/src/operators/kernel/mali/conv_kernel.cpp index f3212cae970b2a554412f59cf48a6e5156463969..36f438605317dd016d2f44cf9c5efc0ab33c5923 100644 --- a/src/operators/kernel/mali/conv_kernel.cpp +++ b/src/operators/kernel/mali/conv_kernel.cpp @@ -195,7 +195,7 @@ class AclConvOp : public acl::ACLOperator { }; template <> -bool ConvKernel::Init(const ConvParam& param) const { +bool ConvKernel::Init(ConvParam* param) { AclConvOp* acl_op = reinterpret_cast*>(this->GetAclOp()); if (acl_op == nullptr) { diff --git a/src/operators/kernel/mali/elementwise_add_kernel.cpp b/src/operators/kernel/mali/elementwise_add_kernel.cpp index 43d33b3fd2b2cc747ae8c943437e675c84a4cdc6..9748bbbb5454f10ad9ea83e37d599fb1c6cdb53e 100644 --- a/src/operators/kernel/mali/elementwise_add_kernel.cpp +++ b/src/operators/kernel/mali/elementwise_add_kernel.cpp @@ -27,8 +27,7 @@ struct AddFunctor { }; template <> -bool ElementwiseAddKernel::Init( - const ElementwiseAddParam ¶) const { +bool ElementwiseAddKernel::Init(ElementwiseAddParam *param) { return true; } diff --git a/src/operators/kernel/mali/fushion_fc_kernel.cpp b/src/operators/kernel/mali/fushion_fc_kernel.cpp index 64ab07a9b955893c01e2684cba0a14fa25d032ed..a76c3c46012a758a05cf8f846a15376ad1b9f33c 100644 --- a/src/operators/kernel/mali/fushion_fc_kernel.cpp +++ 
b/src/operators/kernel/mali/fushion_fc_kernel.cpp @@ -22,7 +22,7 @@ namespace paddle_mobile { namespace operators { template <> -bool FusionFcKernel::Init(const FusionFcParam ¶) const { +bool FusionFcKernel::Init(FusionFcParam *param) { return true; } diff --git a/src/operators/kernel/mali/mul_kernel.cpp b/src/operators/kernel/mali/mul_kernel.cpp index f2a84deaa1de999e94e335de6d4f40981bded5a8..3a9ec4ebb319d9e521240ad987a49549c22c1ff2 100644 --- a/src/operators/kernel/mali/mul_kernel.cpp +++ b/src/operators/kernel/mali/mul_kernel.cpp @@ -22,7 +22,7 @@ namespace paddle_mobile { namespace operators { template <> -bool MulKernel::Init(const MulParam ¶) const { +bool MulKernel::Init(MulParam *param) { return true; } diff --git a/src/operators/kernel/mali/reshape_kernel.cpp b/src/operators/kernel/mali/reshape_kernel.cpp index d7521454d46dfc82064930971d2b996b542af54a..57837a677033590e92a307bd69a77c076c5ba805 100644 --- a/src/operators/kernel/mali/reshape_kernel.cpp +++ b/src/operators/kernel/mali/reshape_kernel.cpp @@ -22,7 +22,7 @@ namespace paddle_mobile { namespace operators { template <> -bool ReshapeKernel::Init(const ReshapeParam ¶) const { +bool ReshapeKernel::Init(ReshapeParam *param) { return true; } diff --git a/src/operators/kernel/mul_kernel.h b/src/operators/kernel/mul_kernel.h index 81db202c2d26fae9abb971a2cafe32f9b20dfe22..f7dcb738b38448fe38eb60dcbbd4a2abda7a858a 100644 --- a/src/operators/kernel/mul_kernel.h +++ b/src/operators/kernel/mul_kernel.h @@ -29,7 +29,7 @@ template class MulKernel : public framework::OpKernelBase { public: void Compute(const MulParam ¶m) const; - bool Init(const MulParam ¶) const; + bool Init(MulParam *param); }; } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/multiclass_nms_kernel.h b/src/operators/kernel/multiclass_nms_kernel.h index ca86604f2c6e550c219e54b6533c1500fb2912c4..9bd00b874a1140373decca582f793febf0e941ec 100644 --- a/src/operators/kernel/multiclass_nms_kernel.h +++ b/src/operators/kernel/multiclass_nms_kernel.h @@ -28,7 +28,7 @@ class MultiClassNMSKernel : public framework::OpKernelBase { public: void Compute(const MultiClassNMSParam& param) const; - bool Init(const MultiClassNMSParam& para) const; + bool Init(MultiClassNMSParam* param); }; } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/pool_kernel.h b/src/operators/kernel/pool_kernel.h index 3285f56cc01fad554bff7e6a4d25769f8ef56d24..d666910b73e7a3cef2cc59d4ba32b826ae6d0876 100644 --- a/src/operators/kernel/pool_kernel.h +++ b/src/operators/kernel/pool_kernel.h @@ -28,7 +28,7 @@ template class PoolKernel : public OpKernelBase { public: void Compute(const PoolParam ¶m) const override; - bool Init(const PoolParam ¶) const; + bool Init(PoolParam *param); }; } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/prior_box_kernel.h b/src/operators/kernel/prior_box_kernel.h index 79fc630b8efb50dec1ff336d2b66d5094eaeb5a5..d169a01d7f45f7dbdcc02be0e1e71690b8550af8 100644 --- a/src/operators/kernel/prior_box_kernel.h +++ b/src/operators/kernel/prior_box_kernel.h @@ -55,7 +55,7 @@ class PriorBoxKernel : public framework::OpKernelBase { public: void Compute(const PriorBoxParam& param) const; - bool Init(const PriorBoxParam& para) const; + bool Init(PriorBoxParam* param); }; } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/relu_kernel.h b/src/operators/kernel/relu_kernel.h index 
2155c33811f553435e4a89b5b23533e2bd42db5d..64016656b20b0fdb08f1342f7853e2e727a6bb81 100644 --- a/src/operators/kernel/relu_kernel.h +++ b/src/operators/kernel/relu_kernel.h @@ -27,7 +27,7 @@ template class ReluKernel : public framework::OpKernelBase { public: void Compute(const ReluParam& param) const; - bool Init(const ReluParam& para) const; + bool Init(ReluParam* param); }; } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/reshape_kernel.h b/src/operators/kernel/reshape_kernel.h index 364f5b0902c2661017f2e72520849836f64dd0bb..47eba531b9f36d83d44588d9cdfb162519c24180 100644 --- a/src/operators/kernel/reshape_kernel.h +++ b/src/operators/kernel/reshape_kernel.h @@ -71,7 +71,7 @@ template class ReshapeKernel : public framework::OpKernelBase { public: void Compute(const ReshapeParam& param) const; - bool Init(const ReshapeParam& para) const; + bool Init(ReshapeParam* param); }; } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/sigmoid_kernel.h b/src/operators/kernel/sigmoid_kernel.h index e9eaae5ad867c6880db7346f9632ff37a92aaf66..fc3eb5e1bf158c541b2f00d9e57ddd4699344006 100644 --- a/src/operators/kernel/sigmoid_kernel.h +++ b/src/operators/kernel/sigmoid_kernel.h @@ -26,7 +26,7 @@ template class SigmoidKernel : public OpKernelBase { public: void Compute(const SigmoidParam& param) const override; - bool Init(const SigmoidParam& para) const; + bool Init(SigmoidParam* param); }; } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/softmax_kernel.h b/src/operators/kernel/softmax_kernel.h index a7a7666e32ef1923a47d71d94c93e813a23028c5..5a87d64dd9987d445b13a4fa9dc29a04e4ecc398 100644 --- a/src/operators/kernel/softmax_kernel.h +++ b/src/operators/kernel/softmax_kernel.h @@ -29,7 +29,7 @@ template class SoftmaxKernel : public OpKernelBase { public: void Compute(const SoftmaxParam ¶m) const override; - bool Init(const SoftmaxParam ¶) const; + bool Init(SoftmaxParam *param); }; } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/transpose_kernel.h b/src/operators/kernel/transpose_kernel.h index 6526d97df9863392f783841a784cb5df4e45f218..f1a21ebbb28c2acdb905ce9f09c28f0d47e17294 100644 --- a/src/operators/kernel/transpose_kernel.h +++ b/src/operators/kernel/transpose_kernel.h @@ -29,7 +29,7 @@ class TransposeKernel : public framework::OpKernelBase { public: void Compute(const TransposeParam& param) const; - bool Init(const TransposeParam& para) const; + bool Init(TransposeParam* param); }; } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/math/depthwise_conv_3x3.cpp b/src/operators/math/depthwise_conv_3x3.cpp index 9c37cdea8fae1b5ec139cefbec82511ce948bff5..f74e365c7e087551e55363566d3dbd6ba530bfea 100644 --- a/src/operators/math/depthwise_conv_3x3.cpp +++ b/src/operators/math/depthwise_conv_3x3.cpp @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "operators/math/depthwise_conv_3x3.h" +#include #include namespace paddle_mobile { @@ -501,6 +502,322 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, } } } +void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, Tensor filter, + Tensor *output, Tensor *bias, bool if_bias, + Tensor *new_scale, Tensor *new_bias, + bool if_bn, bool if_relu) { + const float *input_data = input->data(); + const float *filter_data = filter.data(); + float *output_data = output->data(); + const float *bias_data = bias->data(); + const float *newscale_data = new_scale->data(); + const float *newbias_data = new_bias->data(); + + const int h = static_cast(input->dims()[2]); + const int w = static_cast(input->dims()[3]); + const int l = h; + + const int batch_size = static_cast(input->dims()[0]); + const int c = static_cast(input->dims()[1]); + const int hxw = h * w; + float32x4_t vbias = vdupq_n_f32(0.0); + float32x4_t vnewbias = vdupq_n_f32(0.0); + float32x4_t vnewscale = vdupq_n_f32(1.0); + float32x4_t vzero = vdupq_n_f32(0); + + for (int b = 0; b < batch_size; ++b) { + const float *filter_data_tmp = filter_data; + + for (int j = 0; j < c; ++j) { + if (if_bias) { + vbias = vdupq_n_f32(bias_data[j]); + } + if (if_bn) { + vnewbias = vdupq_n_f32(newbias_data[j]); + vnewscale = vdupq_n_f32(newscale_data[j]); + } + int l_mid = l - 2; // l=1->l_mid=-1,l=2->l_mid=0 + float w00 = filter_data_tmp[0]; + float w01 = filter_data_tmp[1]; + float w02 = filter_data_tmp[2]; + float w10 = filter_data_tmp[3]; + float w11 = filter_data_tmp[4]; + float w12 = filter_data_tmp[5]; + float w20 = filter_data_tmp[6]; + float w21 = filter_data_tmp[7]; + float w22 = filter_data_tmp[8]; + + output_data[0] = + (w11 * input_data[0] + w12 * input_data[1] + w21 * input_data[l] + + w22 * input_data[l + 1] + bias_data[j]) * + newscale_data[j] + + newbias_data[j]; + output_data[l - 1] = (w10 * input_data[l - 2] + w11 * input_data[l - 1] + + w20 * input_data[2 * l - 2] + + w21 * input_data[2 * l - 1] + bias_data[j]) * + newscale_data[j] + + newbias_data[j]; + + output_data[(l - 1) * l] = + (w01 * input_data[(l - 2) * l] + w02 * input_data[(l - 2) * l + 1] + + w11 * input_data[(l - 1) * l] + w12 * input_data[(l - 1) * l + 1] + + bias_data[j]) * + newscale_data[j] + + newbias_data[j]; + output_data[l * l - 1] = (w00 * input_data[(l - 2) * (l + 1)] + + w01 * input_data[(l - 2) * (l + 1) + 1] + + w10 * input_data[l * l - 2] + + w11 * input_data[l * l - 1] + bias_data[j]) * + newscale_data[j] + + newbias_data[j]; + if (if_relu) { + output_data[0] = output_data[0] < 0 ? 0 : output_data[0]; + output_data[l - 1] = output_data[l - 1] < 0 ? 0 : output_data[l - 1]; + output_data[(l - 1) * l] = + output_data[(l - 1) * l] < 0 ? 0 : output_data[(l - 1) * l]; + output_data[l * l - 1] = + output_data[l * l - 1] < 0 ? 
0 : output_data[l * l - 1]; + } + for (int i = 1; i < l - 1; ++i) { + output_data[i * l] = + (w01 * input_data[i * l - l] + w02 * input_data[i * l - l + 1] + + w11 * input_data[i * l] + w12 * input_data[i * l + 1] + + w21 * input_data[i * l + l] + w22 * input_data[i * l + l + 1] + + bias_data[j]) * + newscale_data[j] + + newbias_data[j]; + output_data[i * l + l - 1] = + (w00 * input_data[i * l + l - 1 - l - 1] + + w01 * input_data[i * l + l - 1 - l] + + w10 * input_data[i * l + l - 1 - 1] + + w11 * input_data[i * l + l - 1] + + w20 * input_data[i * l + l - 1 + l - 1] + + w21 * input_data[i * l + l - 1 + l] + bias_data[j]) * + newscale_data[j] + + newbias_data[j]; + if (if_relu) { + output_data[i * l] = output_data[i * l] < 0 ? 0 : output_data[i * l]; + output_data[i * l + l - 1] = + output_data[i * l + l - 1] < 0 ? 0 : output_data[i * l + l - 1]; + } + } + + // top 1 row and bottom 1 row + const float *input_tmp = input_data; + + float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2, + tmp3, tmp4, tmp5, out0; + in0 = vld1q_f32(input_tmp); + in2 = vld1q_f32(input_tmp + l); + const float *input_tmp_end = input_tmp + (l - 2) * l; + in4 = vld1q_f32(input_tmp_end); + in6 = vld1q_f32(input_tmp_end + l); + int c_mid = l_mid; + auto output_ptr = output_data + 1; + for (; c_mid > 3; c_mid -= 4) { + in1 = vld1q_f32(input_tmp + 4); + in3 = vld1q_f32(input_tmp + l + 4); + + tmp0 = vextq_f32(in0, in1, 1); + tmp1 = vextq_f32(in0, in1, 2); + + tmp2 = vextq_f32(in2, in3, 1); + tmp3 = vextq_f32(in2, in3, 2); + + out0 = vmulq_n_f32(in0, w10); + out0 = vmlaq_n_f32(out0, tmp0, w11); + out0 = vmlaq_n_f32(out0, tmp1, w12); + out0 = vmlaq_n_f32(out0, in2, w20); + out0 = vmlaq_n_f32(out0, tmp2, w21); + out0 = vmlaq_n_f32(out0, tmp3, w22); + out0 = vaddq_f32(out0, vbias); + out0 = vmlaq_f32(vnewbias, vnewscale, out0); + if (if_relu) { + out0 = vmaxq_f32(out0, vzero); + } + vst1q_f32(output_ptr, out0); + + in5 = vld1q_f32(input_tmp_end + 4); + in7 = vld1q_f32(input_tmp_end + l + 4); + + tmp0 = vextq_f32(in4, in5, 1); + tmp1 = vextq_f32(in4, in5, 2); + tmp2 = vextq_f32(in6, in7, 1); + tmp3 = vextq_f32(in6, in7, 2); + + out0 = vmulq_n_f32(in4, w00); + out0 = vmlaq_n_f32(out0, tmp0, w01); + out0 = vmlaq_n_f32(out0, tmp1, w02); + out0 = vmlaq_n_f32(out0, in6, w10); + out0 = vmlaq_n_f32(out0, tmp2, w11); + out0 = vmlaq_n_f32(out0, tmp3, w12); + out0 = vaddq_f32(out0, vbias); + out0 = vmlaq_f32(vnewbias, vnewscale, out0); + if (if_relu) { + out0 = vmaxq_f32(out0, vzero); + } + vst1q_f32(output_ptr + (l - 1) * l, out0); + + // can optimize to each 8 stride. 
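// Each pass of the loop above produces four outputs at once: in0..in7 hold
// eight consecutive input floats per row, vextq_f32 builds the two
// column-shifted views a 3x3 window needs, and vmlaq_n_f32 accumulates the
// filter taps. The folded batch norm is then a single fused multiply-add,
// out = new_scale * out + new_bias, via vmlaq_f32(vnewbias, vnewscale, out0),
// and ReLU is a vmaxq_f32 against zero. The pointer bumps below slide the
// window four columns to the right for the next iteration.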
+ input_tmp += 4; + input_tmp_end += 4; + output_ptr += 4; + in0 = in1; + in2 = in3; + in4 = in5; + in6 = in7; + } + + // top right pad + float32x4_t pad0 = vdupq_n_f32(input_data[l - 1]); + float32x4_t pad1 = vdupq_n_f32(input_data[2 * l - 1]); + + tmp0 = vextq_f32(in0, pad0, 1); + tmp1 = vextq_f32(in0, pad0, 2); + tmp2 = vextq_f32(in2, pad1, 1); + tmp3 = vextq_f32(in2, pad1, 2); + + out0 = vmulq_n_f32(in0, w10); + out0 = vmlaq_n_f32(out0, tmp0, w11); + out0 = vmlaq_n_f32(out0, tmp1, w12); + out0 = vmlaq_n_f32(out0, in2, w20); + out0 = vmlaq_n_f32(out0, tmp2, w21); + out0 = vmlaq_n_f32(out0, tmp3, w22); + out0 = vaddq_f32(out0, vbias); + out0 = vmlaq_f32(vnewbias, vnewscale, out0); + if (if_relu) { + out0 = vmaxq_f32(out0, vzero); + } + for (int i = 0; i < c_mid; ++i) { + if (i == 0) { + vst1q_lane_f32(output_ptr + i, out0, 0); + } + if (i == 1) { + vst1q_lane_f32(output_ptr + i, out0, 1); + } + if (i == 2) { + vst1q_lane_f32(output_ptr + i, out0, 2); + } + } + + // bottom right pad + float32x4_t pad2 = vdupq_n_f32(input_data[l * l - 1 - l]); + float32x4_t pad3 = vdupq_n_f32(input_data[l * l - 1]); + + tmp0 = vextq_f32(in4, pad2, 1); + tmp1 = vextq_f32(in4, pad2, 2); + tmp2 = vextq_f32(in6, pad3, 1); + tmp3 = vextq_f32(in6, pad3, 2); + + out0 = vmulq_n_f32(in4, w00); + out0 = vmlaq_n_f32(out0, tmp0, w01); + out0 = vmlaq_n_f32(out0, tmp1, w02); + out0 = vmlaq_n_f32(out0, in6, w10); + out0 = vmlaq_n_f32(out0, tmp2, w11); + out0 = vmlaq_n_f32(out0, tmp3, w12); + out0 = vaddq_f32(out0, vbias); + out0 = vmlaq_f32(vnewbias, vnewscale, out0); + if (if_relu) { + out0 = vmaxq_f32(out0, vzero); + } + for (int i = 0; i < c_mid; ++i) { + if (i == 0) { + vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 0); + } + if (i == 1) { + vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 1); + } + if (i == 2) { + vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 2); + } + } + // mid + + for (int i = 0; i < l - 2; ++i) { + auto output_ptr = output_data + (i + 1) * l + 1; + input_tmp = input_data + i * l; + auto in0_tmp = vld1q_f32(input_tmp); + auto in2_tmp = vld1q_f32(input_tmp + l); + auto in4_tmp = vld1q_f32(input_tmp + l + l); + c_mid = l_mid; + for (; c_mid > 3; c_mid -= 4) { + auto in1_tmp = vld1q_f32(input_tmp + 4); + auto in3_tmp = vld1q_f32(input_tmp + l + 4); + auto in5_tmp = vld1q_f32(input_tmp + l + l + 4); + + tmp0 = vextq_f32(in0_tmp, in1_tmp, 1); + tmp1 = vextq_f32(in0_tmp, in1_tmp, 2); + tmp2 = vextq_f32(in2_tmp, in3_tmp, 1); + tmp3 = vextq_f32(in2_tmp, in3_tmp, 2); + tmp4 = vextq_f32(in4_tmp, in5_tmp, 1); + tmp5 = vextq_f32(in4_tmp, in5_tmp, 2); + + out0 = vmulq_n_f32(in0_tmp, w00); + out0 = vmlaq_n_f32(out0, tmp0, w01); + out0 = vmlaq_n_f32(out0, tmp1, w02); + out0 = vmlaq_n_f32(out0, in2_tmp, w10); + out0 = vmlaq_n_f32(out0, tmp2, w11); + out0 = vmlaq_n_f32(out0, tmp3, w12); + out0 = vmlaq_n_f32(out0, in4_tmp, w20); + out0 = vmlaq_n_f32(out0, tmp4, w21); + out0 = vmlaq_n_f32(out0, tmp5, w22); + out0 = vaddq_f32(out0, vbias); + out0 = vmlaq_f32(vnewbias, vnewscale, out0); + if (if_relu) { + out0 = vmaxq_f32(out0, vzero); + } + vst1q_f32(output_ptr, out0); + + output_ptr += 4; + input_tmp += 4; + in0_tmp = in1_tmp; + in2_tmp = in3_tmp; + in4_tmp = in5_tmp; + } + + float32x4_t pad0 = vdupq_n_f32(input_data[i * l + l - 1]); + float32x4_t pad1 = vdupq_n_f32(input_data[i * l + l - 1 + l]); + float32x4_t pad2 = vdupq_n_f32(input_data[i * l + l - 1 + l + l]); + + tmp0 = vextq_f32(in0_tmp, pad0, 1); + tmp1 = vextq_f32(in0_tmp, pad0, 2); + tmp2 = vextq_f32(in2_tmp, pad1, 1); + tmp3 = 
diff --git a/src/operators/math/depthwise_conv_3x3.h b/src/operators/math/depthwise_conv_3x3.h
index ab2a04369e1fc6e984ffa6f8f5667dd2a10e2a55..44299295eebad6a90fd994cf74589c09a3573aee 100644
--- a/src/operators/math/depthwise_conv_3x3.h
+++ b/src/operators/math/depthwise_conv_3x3.h
@@ -32,6 +32,10 @@ void DepthwiseConv3x3(const Tensor *input, vector<int> strides,
                       Tensor *output, bool if_bias);
 void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
                           Tensor *output, Tensor *bias, bool if_bias);
+void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, Tensor filter,
+                                   Tensor *output, Tensor *bias, bool if_bias,
+                                   Tensor *new_scale, Tensor *new_bias,
+                                   bool if_bn, bool if_relu);
 } // namespace math
 } // namespace operators
 } // namespace paddle_mobile
diff --git a/src/operators/math/depthwiseconv3x3s1p1.cpp b/src/operators/math/depthwiseconv3x3s1p1.cpp
deleted file mode 100644
index 88cac515201c114e83cb9e85b39a51fb3f8e7955..0000000000000000000000000000000000000000
--- a/src/operators/math/depthwiseconv3x3s1p1.cpp
+++ /dev/null
@@ -1,288 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "operators/math/depthwiseconv3x3s1p1.h"
-#include <arm_neon.h>
-
-namespace paddle_mobile {
-namespace operators {
-namespace math {
-
-using framework::Tensor;
-
-void DepthwiseConv3x3s1p1(const Tensor *input, Tensor filter, Tensor *output,
-                          Tensor bias, bool if_bias) {
-  const float *input_data = input->data<float>();
-  const float *filter_data = filter.data<float>();
-  float *output_data = output->data<float>();
-  const float *bias_data = bias.data<float>();
-
-  const int h = static_cast<int>(input->dims()[2]);
-  const int w = static_cast<int>(input->dims()[3]);
-  const int l = h;
-
-  const int batch_size = static_cast<int>(input->dims()[0]);
-  const int c = static_cast<int>(input->dims()[1]);
-  const int hxw = h * w;
-  float32x4_t vbias = vdupq_n_f32(0.0);
-  for (int b = 0; b < batch_size; ++b) {
-    const float *filter_data_tmp = filter_data;
-
-    for (int j = 0; j < c; ++j) {
-      if (if_bias) {
-        vbias = vdupq_n_f32(bias_data[j]);
-      }
-
-      int l_mid = l - 2;  // l=1->l_mid=-1,l=2->l_mid=0
-      float w00 = filter_data_tmp[0];
-      float w01 = filter_data_tmp[1];
-      float w02 = filter_data_tmp[2];
-      float w10 = filter_data_tmp[3];
-      float w11 = filter_data_tmp[4];
-      float w12 = filter_data_tmp[5];
-      float w20 = filter_data_tmp[6];
-      float w21 = filter_data_tmp[7];
-      float w22 = filter_data_tmp[8];
-
-      output_data[0] = w11 * input_data[0] + w12 * input_data[1] +
-                       w21 * input_data[l] + w22 * input_data[l + 1] +
-                       bias_data[j];
-      output_data[l - 1] = w10 * input_data[l - 2] + w11 * input_data[l - 1] +
-                           w20 * input_data[2 * l - 2] +
-                           w21 * input_data[2 * l - 1] + bias_data[j];
-      output_data[(l - 1) * l] =
-          w01 * input_data[(l - 2) * l] + w02 * input_data[(l - 2) * l + 1] +
-          w11 * input_data[(l - 1) * l] + w12 * input_data[(l - 1) * l + 1] +
-          bias_data[j];
-      output_data[l * l - 1] = w00 * input_data[(l - 2) * (l + 1)] +
-                               w01 * input_data[(l - 2) * (l + 1) + 1] +
-                               w10 * input_data[l * l - 2] +
-                               w11 * input_data[l * l - 1] + bias_data[j];
-
-      for (int i = 1; i < l - 1; ++i) {
-        output_data[i * l] =
-            w01 * input_data[i * l - l] + w02 * input_data[i * l - l + 1] +
-            w11 * input_data[i * l] + w12 * input_data[i * l + 1] +
-            w21 * input_data[i * l + l] + w22 * input_data[i * l + l + 1] +
-            bias_data[j];
-        output_data[i * l + l - 1] = w00 * input_data[i * l + l - 1 - l - 1] +
-                                     w01 * input_data[i * l + l - 1 - l] +
-                                     w10 * input_data[i * l + l - 1 - 1] +
-                                     w11 * input_data[i * l + l - 1] +
-                                     w20 * input_data[i * l + l - 1 + l - 1] +
-                                     w21 * input_data[i * l + l - 1 + l] +
-                                     bias_data[j];
-      }
-
-      // top 1 row and bottom 1 row
-      const float *input_tmp = input_data;
-
-      float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2,
-          tmp3, tmp4, tmp5, out0;
-      in0 = vld1q_f32(input_tmp);
-      in2 = vld1q_f32(input_tmp + l);
-      const float *input_tmp_end = input_tmp + (l - 2) * l;
-      in4 = vld1q_f32(input_tmp_end);
-      in6 = vld1q_f32(input_tmp_end + l);
-      int c_mid = l_mid;
-      auto output_ptr = output_data + 1;
-      for (; c_mid > 3; c_mid -= 4) {
-        in1 = vld1q_f32(input_tmp + 4);
-        in3 = vld1q_f32(input_tmp + l + 4);
-
-        tmp0 = vextq_f32(in0, in1, 1);
-        tmp1 = vextq_f32(in0, in1, 2);
-
-        tmp2 = vextq_f32(in2, in3, 1);
-        tmp3 = vextq_f32(in2, in3, 2);
-
-        out0 = vmulq_n_f32(in0, w10);
-        out0 = vmlaq_n_f32(out0, tmp0, w11);
-        out0 = vmlaq_n_f32(out0, tmp1, w12);
-        out0 = vmlaq_n_f32(out0, in2, w20);
-        out0 = vmlaq_n_f32(out0, tmp2, w21);
-        out0 = vmlaq_n_f32(out0, tmp3, w22);
-        out0 = vaddq_f32(out0, vbias);
-
-        vst1q_f32(output_ptr, out0);
-
-        in5 = vld1q_f32(input_tmp_end + 4);
-        in7 = vld1q_f32(input_tmp_end + l + 4);
-
-        tmp0 = vextq_f32(in4, in5, 1);
-        tmp1 = vextq_f32(in4, in5, 2);
-        tmp2 = vextq_f32(in6, in7, 1);
-        tmp3 = vextq_f32(in6, in7, 2);
-
-        out0 = vmulq_n_f32(in4, w00);
-        out0 = vmlaq_n_f32(out0, tmp0, w01);
-        out0 = vmlaq_n_f32(out0, tmp1, w02);
-        out0 = vmlaq_n_f32(out0, in6, w10);
-        out0 = vmlaq_n_f32(out0, tmp2, w11);
-        out0 = vmlaq_n_f32(out0, tmp3, w12);
-        out0 = vaddq_f32(out0, vbias);
-
-        vst1q_f32(output_ptr + (l - 1) * l, out0);
-
-        // can optimize to each 8 stride.
-        input_tmp += 4;
-        input_tmp_end += 4;
-        output_ptr += 4;
-        in0 = in1;
-        in2 = in3;
-        in4 = in5;
-        in6 = in7;
-      }
-
-      // top right pad
-      float32x4_t pad0 = vdupq_n_f32(input_data[l - 1]);
-      float32x4_t pad1 = vdupq_n_f32(input_data[2 * l - 1]);
-
-      tmp0 = vextq_f32(in0, pad0, 1);
-      tmp1 = vextq_f32(in0, pad0, 2);
-      tmp2 = vextq_f32(in2, pad1, 1);
-      tmp3 = vextq_f32(in2, pad1, 2);
-
-      out0 = vmulq_n_f32(in0, w10);
-      out0 = vmlaq_n_f32(out0, tmp0, w11);
-      out0 = vmlaq_n_f32(out0, tmp1, w12);
-      out0 = vmlaq_n_f32(out0, in2, w20);
-      out0 = vmlaq_n_f32(out0, tmp2, w21);
-      out0 = vmlaq_n_f32(out0, tmp3, w22);
-      out0 = vaddq_f32(out0, vbias);
-
-      for (int i = 0; i < c_mid; ++i) {
-        if (i == 0) {
-          vst1q_lane_f32(output_ptr + i, out0, 0);
-        }
-        if (i == 1) {
-          vst1q_lane_f32(output_ptr + i, out0, 1);
-        }
-        if (i == 2) {
-          vst1q_lane_f32(output_ptr + i, out0, 2);
-        }
-      }
-
-      // bottom right pad
-      float32x4_t pad2 = vdupq_n_f32(input_data[l * l - 1 - l]);
-      float32x4_t pad3 = vdupq_n_f32(input_data[l * l - 1]);
-
-      tmp0 = vextq_f32(in4, pad2, 1);
-      tmp1 = vextq_f32(in4, pad2, 2);
-      tmp2 = vextq_f32(in6, pad3, 1);
-      tmp3 = vextq_f32(in6, pad3, 2);
-
-      out0 = vmulq_n_f32(in4, w00);
-      out0 = vmlaq_n_f32(out0, tmp0, w01);
-      out0 = vmlaq_n_f32(out0, tmp1, w02);
-      out0 = vmlaq_n_f32(out0, in6, w10);
-      out0 = vmlaq_n_f32(out0, tmp2, w11);
-      out0 = vmlaq_n_f32(out0, tmp3, w12);
-      out0 = vaddq_f32(out0, vbias);
-
-      for (int i = 0; i < c_mid; ++i) {
-        if (i == 0) {
-          vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 0);
-        }
-        if (i == 1) {
-          vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 1);
-        }
-        if (i == 2) {
-          vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 2);
-        }
-      }
-      // mid
-
-      for (int i = 0; i < l - 2; ++i) {
-        auto output_ptr = output_data + (i + 1) * l + 1;
-        input_tmp = input_data + i * l;
-        auto in0_tmp = vld1q_f32(input_tmp);
-        auto in2_tmp = vld1q_f32(input_tmp + l);
-        auto in4_tmp = vld1q_f32(input_tmp + l + l);
-        c_mid = l_mid;
-        for (; c_mid > 3; c_mid -= 4) {
-          auto in1_tmp = vld1q_f32(input_tmp + 4);
-          auto in3_tmp = vld1q_f32(input_tmp + l + 4);
-          auto in5_tmp = vld1q_f32(input_tmp + l + l + 4);
-
-          tmp0 = vextq_f32(in0_tmp, in1_tmp, 1);
-          tmp1 = vextq_f32(in0_tmp, in1_tmp, 2);
-          tmp2 = vextq_f32(in2_tmp, in3_tmp, 1);
-          tmp3 = vextq_f32(in2_tmp, in3_tmp, 2);
-          tmp4 = vextq_f32(in4_tmp, in5_tmp, 1);
-          tmp5 = vextq_f32(in4_tmp, in5_tmp, 2);
-
-          out0 = vmulq_n_f32(in0_tmp, w00);
-          out0 = vmlaq_n_f32(out0, tmp0, w01);
-          out0 = vmlaq_n_f32(out0, tmp1, w02);
-          out0 = vmlaq_n_f32(out0, in2_tmp, w10);
-          out0 = vmlaq_n_f32(out0, tmp2, w11);
-          out0 = vmlaq_n_f32(out0, tmp3, w12);
-          out0 = vmlaq_n_f32(out0, in4_tmp, w20);
-          out0 = vmlaq_n_f32(out0, tmp4, w21);
-          out0 = vmlaq_n_f32(out0, tmp5, w22);
-          out0 = vaddq_f32(out0, vbias);
-
-          vst1q_f32(output_ptr, out0);
-
-          output_ptr += 4;
-          input_tmp += 4;
-          in0_tmp = in1_tmp;
-          in2_tmp = in3_tmp;
-          in4_tmp = in5_tmp;
-        }
-
-        float32x4_t pad0 = vdupq_n_f32(input_data[i * l + l - 1]);
-        float32x4_t pad1 = vdupq_n_f32(input_data[i * l + l - 1 + l]);
-        float32x4_t pad2 = vdupq_n_f32(input_data[i * l + l - 1 + l + l]);
-
-        tmp0 = vextq_f32(in0_tmp, pad0, 1);
-        tmp1 = vextq_f32(in0_tmp, pad0, 2);
-        tmp2 = vextq_f32(in2_tmp, pad1, 1);
-        tmp3 = vextq_f32(in2_tmp, pad1, 2);
-        tmp4 = vextq_f32(in4_tmp, pad2, 1);
-        tmp5 = vextq_f32(in4_tmp, pad2, 2);
-
-        out0 = vmulq_n_f32(in0_tmp, w00);
-        out0 = vmlaq_n_f32(out0, tmp0, w01);
-        out0 = vmlaq_n_f32(out0, tmp1, w02);
-        out0 = vmlaq_n_f32(out0, in2_tmp, w10);
-        out0 = vmlaq_n_f32(out0, tmp2, w11);
-        out0 = vmlaq_n_f32(out0, tmp3, w12);
-        out0 = vmlaq_n_f32(out0, in4_tmp, w20);
-        out0 = vmlaq_n_f32(out0, tmp4, w21);
-        out0 = vmlaq_n_f32(out0, tmp5, w22);
-        out0 = vaddq_f32(out0, vbias);
-
-        for (int i = 0; i < c_mid; ++i) {
-          if (i == 0) {
-            vst1q_lane_f32(output_ptr + i, out0, 0);
-          }
-          if (i == 1) {
-            vst1q_lane_f32(output_ptr + i, out0, 1);
-          }
-          if (i == 2) {
-            vst1q_lane_f32(output_ptr + i, out0, 2);
-          }
-        }
-      }
-      output_data += hxw;
-      input_data += hxw;
-      filter_data_tmp += 9;
-    }
-  }
-}
-} // namespace math
-} // namespace operators
-} // namespace paddle_mobile
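The standalone depthwiseconv3x3s1p1.cpp is deleted above; DepthwiseConv3x3s1p1 is still declared in depthwise_conv_3x3.h, so its implementation appears to have been consolidated into depthwise_conv_3x3.cpp alongside the new fused kernel. A call to the fused entry point declared earlier would look roughly like this (a sketch against that declaration; the tensor variables are illustrative):

// Hypothetical call site (all tensor variables illustrative). Note that
// filter is passed by value, matching the declaration in depthwise_conv_3x3.h.
math::DepthwiseConvAddBNRelu3x3s1p1(input, filter, output, bias,
                                    /*if_bias=*/true, new_scale, new_bias,
                                    /*if_bn=*/true, /*if_relu=*/true);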
diff --git a/src/operators/op_param.h b/src/operators/op_param.h
index ad7de0ee44db3a727ec06d5fabfca203226215f4..229779a127a8dc828da1a1c7ccb3ffe188073f47 100644
--- a/src/operators/op_param.h
+++ b/src/operators/op_param.h
@@ -848,5 +848,91 @@ class FusionConvAddReluParam : public FusionConvAddParam {
 };
 #endif
+#ifdef FUSION_CONVADDBNRELU_OP
+class FusionConvAddBNReluParam : public OpParam {
+ public:
+  FusionConvAddBNReluParam(const VariableNameMap &inputs,
+                           const VariableNameMap &outputs,
+                           const AttributeMap &attrs, const Scope &scope) {
+    bias_ = InputYFrom(inputs, scope);
+    axis_ = GetAttr<int>("axis", attrs);
+    filter_ = FilterFrom(inputs, scope);
+    input_ = InputFrom(inputs, scope);
+    output_ = OutFrom(outputs, scope);
+    strides_ = GetAttr<vector<int>>("strides", attrs);
+    paddings_ = GetAttr<vector<int>>("paddings", attrs);
+    dilations_ = GetAttr<vector<int>>("dilations", attrs);
+    groups = GetAttr<int>("groups", attrs);
+    input_bias_ = InputBiasFrom(inputs, scope);
+    input_mean_ = InputMeanFrom(inputs, scope);
+    input_scale_ = InputScaleFrom(inputs, scope);
+    input_variance_ = InputVarianceFrom(inputs, scope);
+    epsilon_ = GetAttr<float>("epsilon", attrs);
+    momentum_ = GetAttr<float>("momentum", attrs);
+    is_test_ = GetAttr<bool>("is_test", attrs);
+  }
+  Tensor *Bias() const { return bias_; }
+
+  const int &Axis() const { return axis_; }
+
+  const Tensor *Input() const { return input_; }
+
+  const Tensor *Filter() const { return filter_; }
+
+  Tensor *Output() const { return output_; }
+
+  const vector<int> &Strides() const { return strides_; }
+
+  const vector<int> &Paddings() const { return paddings_; }
+
+  const vector<int> &Dilations() const { return dilations_; }
+
+  const int &Groups() const { return groups; }
+
+  const Tensor *InputBias() const { return input_bias_; }
+
+  const Tensor *InputMean() const { return input_mean_; }
+
+  const Tensor *InputScale() const { return input_scale_; }
+
+  const Tensor *InputVariance() const { return input_variance_; }
+
+  const float &Epsilon() const { return epsilon_; }
+
+  const float &Momentum() const { return momentum_; }
+
+  const bool &IsTest() const { return is_test_; }
+
+  void SetNewScale(Tensor *new_scale) { new_scale_ = new_scale; }
+
+  void SetNewBias(Tensor *new_bias) { new_bias_ = new_bias; }
+
+  const Tensor *NewScale() const { return new_scale_; }
+
+  const Tensor *NewBias() const { return new_bias_; }
+
+ protected:
+  Tensor *bias_;
+  int axis_;
+  Tensor *input_;
+  Tensor *output_;
+  Tensor *filter_;
+  vector<int> strides_;
+  vector<int> paddings_;
+  vector<int> dilations_;
+  int groups;
+  Tensor *input_bias_;
+  Tensor *input_mean_;
+  Tensor *input_scale_;
+  Tensor *input_variance_;
+  float epsilon_;
+  float momentum_;
+  bool is_test_;
+  Tensor *new_bias_;
+  Tensor *new_scale_;
+};
+
+Print &operator<<(Print &printer, const FusionConvAddParam &conv_param);
+#endif
 } // namespace operators
 } // namespace paddle_mobile
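SetNewScale/SetNewBias exist because inference-time batch norm, y = scale * (x - mean) / sqrt(variance + epsilon) + bias, collapses into a single affine transform y = new_scale * x + new_bias, with new_scale = scale / sqrt(variance + epsilon) and new_bias = bias - mean * new_scale. The folding itself is not part of this diff (presumably it happens once in the kernel's Init); a per-channel sketch under that assumption, with a hypothetical helper name:

#include <cmath>

// Sketch: fold inference-time batch norm into per-channel scale/bias.
// Assumes each array holds one float per output channel.
void FoldBatchNorm(const float *scale, const float *bias, const float *mean,
                   const float *variance, float epsilon, int channels,
                   float *new_scale, float *new_bias) {
  for (int c = 0; c < channels; ++c) {
    const float inv_std = 1.f / std::sqrt(variance[c] + epsilon);
    new_scale[c] = scale[c] * inv_std;
    new_bias[c] = bias[c] - mean[c] * new_scale[c];
  }
}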
diff --git a/tools/build.sh b/tools/build.sh
index 42e872c580cffef3bd904dc9cc575e9961ef4257..86ae8b5e1aa16c7cab66580bc2eaa6e1e526fc17 100755
--- a/tools/build.sh
+++ b/tools/build.sh
@@ -31,7 +31,7 @@ build_for_mac() {
 }
 
 build_for_android() {
-    rm -rf "../build"
+    #rm -rf "../build"
     if [ -z "${ANDROID_NDK}" ]; then
         echo "ANDROID_NDK not found!"
         exit -1
diff --git a/tools/op.cmake b/tools/op.cmake
index 2eabac925f6021448243b3668c22cbcaebe2f1d9..c7d1840bc552d6b45fe2fcc9f8d19e6784598ee6 100644
--- a/tools/op.cmake
+++ b/tools/op.cmake
@@ -22,6 +22,7 @@ elseif (NET EQUAL "mobilenet")
     set(BATCHNORM_OP ON)
     set(POOL_OP ON)
    set(RESHAPE_OP ON)
+    set(FUSION_CONVADDBNRELU_OP ON)
 elseif (NET EQUAL "yolo")
     set(BATCHNORM_OP ON)
     set(CONV_OP ON)
@@ -64,6 +65,8 @@ else ()
     set(SOFTMAX_OP ON)
     set(TRANSPOSE_OP ON)
     set(FUSION_CONVADD_RELU_OP ON)
+    set(FUSION_CONVADDBNRELU_OP ON)
+
     # option(BATCHNORM_OP "" ON)
     # option(BOXCODER_OP "" ON)
     # option(CONCAT_OP "" ON)
@@ -145,4 +148,7 @@ if (TRANSPOSE_OP)
 endif()
 if (FUSION_CONVADD_RELU_OP)
     add_definitions(-DFUSION_CONVADD_RELU_OP)
-endif()
\ No newline at end of file
+endif()
+if (FUSION_CONVADDBNRELU_OP)
+    add_definitions(-DFUSION_CONVADDBNRELU_OP)
+endif()
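The op.cmake switch acts purely through the preprocessor: add_definitions(-DFUSION_CONVADDBNRELU_OP) activates every block guarded by that macro, such as the op_param.h addition above. A minimal, standalone way to confirm which way a translation unit was configured (an illustration only, not part of the project):

#include <cstdio>

int main() {
#ifdef FUSION_CONVADDBNRELU_OP
  std::printf("fused conv+add+bn+relu support compiled in\n");
#else
  std::printf("fused conv+add+bn+relu support disabled\n");
#endif
  return 0;
}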