diff --git a/src/operators/fusion_fc_relu_op.cpp b/src/operators/fusion_fc_relu_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e90c4d05b692d5548c4e45bdf33a3cd836bf9d09 --- /dev/null +++ b/src/operators/fusion_fc_relu_op.cpp @@ -0,0 +1,68 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef FUSION_FC_RELU_OP + +#include "operators/fusion_fc_relu_op.h" +namespace paddle_mobile { + namespace operators { + + template + void FusionFcReluOp::InferShape() const { + auto x_dims = this->param_.InputX()->dims(); + auto y_dims = this->param_.InputY()->dims(); + int x_num_col_dims = this->param_.XNumColDims(); + int y_num_col_dims = this->param_.YNumColDims(); + + assert(x_dims.size() > x_num_col_dims); + assert(y_dims.size() > y_num_col_dims); + + /// (1,2,3,4) , x_num_col_dims = 2 -> (2,12) + auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims); + auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_col_dims); + + assert(x_mat_dims[1] == y_mat_dims[0]); + + std::vector output_dims; + output_dims.reserve( + static_cast(x_num_col_dims + y_dims.size() - y_num_col_dims)); + + for (int i = 0; i < x_num_col_dims; ++i) { + output_dims.push_back(x_dims[i]); + } + + for (int i = y_num_col_dims; i < y_dims.size(); ++i) { + output_dims.push_back(y_dims[i]); + } + + framework::DDim ddim = framework::make_ddim(output_dims); + this->param_.Out()->Resize(ddim); + } + + } // namespace operators +} // namespace paddle_mobile + +namespace ops = paddle_mobile::operators; +#ifdef PADDLE_MOBILE_CPU +REGISTER_OPERATOR_CPU(fusion_fc_relu, ops::FusionFcReluOp); +#endif +#ifdef PADDLE_MOBILE_MALI_GPU +REGISTER_OPERATOR_MALI_GPU(fusion_fc_relu, ops::FusionFcReluOp); +#endif +#ifdef PADDLE_MOBILE_FPGA +#endif + +#endif + + diff --git a/src/operators/fusion_fc_relu_op.h b/src/operators/fusion_fc_relu_op.h new file mode 100644 index 0000000000000000000000000000000000000000..69fec4af29fe8c78f113a8e25933bc4dfe9974f7 --- /dev/null +++ b/src/operators/fusion_fc_relu_op.h @@ -0,0 +1,98 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#ifdef FUSION_FC_RELU_OP +#pragma once +#include +#include + +#include "framework/operator.h" +#include "framework/program/program-optimize/fusion_op_register.h" +#include "operators/kernel/fusion_fc_relu_kernel.h" + +namespace paddle_mobile { +namespace operators { +using std::string; +using std::vector; +class FusionFcReluMatcher : public framework::FusionOpMatcher { + public: + FusionFcReluMatcher() { + node_ = framework::Node(G_OP_TYPE_MUL); + node_ > std::make_shared(G_OP_TYPE_ELEMENTWISE_ADD) > + std::make_shared(G_OP_TYPE_RELU); + } + + void FolderNodes( + framework::Node *node, + std::vector> *removed_nodes) { + node->Folder(node_.Depth(), Type(), + {{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Z"}}}}, removed_nodes); + } + + std::string Type() { return G_OP_TYPE_FUSION_FC_RELU; } +}; + +template +class FusionFcReluOp + : public framework::OperatorWithKernel< + DeviceType, FusionFcReluParam, operators::FusionFcReluKernel> { + public: + FusionFcReluOp(const string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, + const framework::AttributeMap &attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel>( + type, inputs, outputs, attrs, scope) {} + + using framework::OperatorWithKernel< + DeviceType, FusionFcReluParam, + operators::FusionFcReluKernel>::OperatorWithKernel; + void InferShape() const override; + + protected: +}; + +#ifdef PADDLE_MOBILE_CPU + +#ifndef FUSION_FC_RELU_REGISTER +#define FUSION_FC_RELU_REGISTER +static framework::FusionOpRegistrar fc_relu_registrar(new FusionFcReluMatcher()); +#endif + +#endif + +#ifdef PADDLE_MOBILE_MALI_GPU + +#ifndef FUSION_FC_RELU_REGISTER +#define FUSION_FC_RELU_REGISTER +static framework::FusionOpRegistrar fc_relu_registrar(new FusionFcReluMatcher()); +#endif + +#endif + +#ifdef PADDLE_MOBILE_FPGA +#endif + +} // namespace operators +} // namespace paddle_mobile + +#ifdef PADDLE_MOBILE_CPU +USE_OP_CPU(fusion_fc_relu); +#endif +#ifdef PADDLE_MOBILE_MALI_GPU +USE_OP_MALI_GPU(fusion_fc_relu); +#endif +#ifdef PADDLE_MOBILE_FPGA +#endif +#endif //FUSION_FC_RELU_OP diff --git a/src/operators/kernel/fusion_fc_relu_kernal.h b/src/operators/kernel/fusion_fc_relu_kernal.h new file mode 100644 index 0000000000000000000000000000000000000000..9f9e84d5b0d3559566986bdbecc2ddf36d73ccb4 --- /dev/null +++ b/src/operators/kernel/fusion_fc_relu_kernal.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef FUSION_FC_RELU_OP + +#pragma once + +#include "framework/operator.h" +#include "operators/math/math_function.h" +#include "operators/op_param.h" + +namespace paddle_mobile { + namespace operators { + + template + class FusionFcReluKernel + : public framework::OpKernelBase { + public: + void Compute(const FusionFcReluParam& param) const; + bool Init(FusionFcReluParam* param); + }; + } // namespace operators +} // namespace paddle_mobile + +#endif +