diff --git a/src/framework/operator.h b/src/framework/operator.h index c85a38d73c620ae4b08387b548bd2f4f8ca71711..c68744a676030413e81570ded0db5671cdf4ba7a 100644 --- a/src/framework/operator.h +++ b/src/framework/operator.h @@ -138,9 +138,21 @@ class OpKernelBase { * @p para 这个参数为 kernel 运算时所需要用到参数组成的一个结构体, * 所有结构体存在与: paddle-mobile/src/operators/op_param.h * */ +#ifdef PADDLE_MOBILE_MALI_GPU + OpKernelBase() { acl_op_ = nullptr; } + void *GetAclOp() const { return acl_op_; } + void SetAclOp(void *op, void *ob) const { + reinterpret_cast *>(ob)->acl_op_ = op; + } +#endif virtual void Compute(const P ¶) const = 0; virtual bool Init(const P ¶) const { return true; }; virtual ~OpKernelBase() = default; + + private: +#ifdef PADDLE_MOBILE_MALI_GPU + void *acl_op_; +#endif }; #define DEFINE_OP_CONSTRUCTOR(cls, parent_cls) \ diff --git a/src/framework/program/program-optimize/node.cpp b/src/framework/program/program-optimize/node.cpp index 4ea45ec0a859ef8aa3ab4e34de8279e732706803..e635e07eaf4484c3e390101c3b43fdaf24bbd2c6 100644 --- a/src/framework/program/program-optimize/node.cpp +++ b/src/framework/program/program-optimize/node.cpp @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "framework/program/program-optimize/node.h" +#include #include "framework/operator.h" namespace paddle_mobile { @@ -92,7 +93,8 @@ int Node::Depth(int begin) { Node &Node::Folder( int size, std::string type, - std::map> change, + std::map>> + change, std::vector> *removed_nodes) { std::shared_ptr op_desc = std::make_shared(); @@ -109,12 +111,15 @@ Node &Node::Folder( void Node::Folder( std::shared_ptr op_desc, std::vector> *outputs, int index, - std::map> *change, + std::map>> + *change, Node *begin_node, std::vector> *removed_nodes) { if (change->find(this->type_) != change->end()) { - auto change_pair = (*change)[this->type_]; - op_desc->GetInputs()[change_pair.second] = - this->op_desc_->GetInputs()[change_pair.first]; + auto change_pairs = (*change)[this->type_]; + for (const auto &change_pair : change_pairs) { + op_desc->GetInputs()[change_pair.second] = + this->op_desc_->GetInputs()[change_pair.first]; + } } for (auto &attr_pair : this->op_desc_->attrs_) { diff --git a/src/framework/program/program-optimize/node.h b/src/framework/program/program-optimize/node.h index 7236ffdd1782dfb39af73195da9b3756030c9117..88bf1e16ed2a5fb3a038eadd546d63ffb3916f68 100644 --- a/src/framework/program/program-optimize/node.h +++ b/src/framework/program/program-optimize/node.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include +#include #include #include "common/log.h" #include "framework/program/op_desc.h" @@ -43,7 +44,8 @@ class Node { int Depth(int begin = 0); Node &Folder( int size, std::string type, - std::map> change_map, + std::map>> + change, std::vector> *removed_nodes); std::vector> OpDescs(int size); std::shared_ptr OpDescOfNode() { return op_desc_; } @@ -56,7 +58,8 @@ class Node { void Folder( std::shared_ptr op_desc, std::vector> *outputs, int index, - std::map> *change, + std::map>> + *change, Node *begin_node, std::vector> *removed_nodes); std::shared_ptr op_desc_; #ifdef PADDLE_MOBILE_DEBUG diff --git a/src/operators/feed_op.h b/src/operators/feed_op.h index 034cf947871a962b786b66e3752d86f5a327f342..bd5fd8cb32d484b7f76652139603f6b0f1b4b5d7 100644 --- a/src/operators/feed_op.h +++ b/src/operators/feed_op.h @@ -50,6 +50,8 @@ USE_OP_CPU(feed); REGISTER_OPERATOR_CPU(feed, ops::FeedOp); #endif #ifdef PADDLE_MOBILE_MALI_GPU +USE_OP_MALI_GPU(feed); +REGISTER_OPERATOR_MALI_GPU(feed, ops::FeedOp); #endif #ifdef PADDLE_MOBILE_FPGA #endif diff --git a/src/operators/fetch_op.h b/src/operators/fetch_op.h index c28424f0d1880c9f7f44c6644a163215d639f7a3..4b3680b58357d8295b1b6acf111d3573d4e4d1bd 100644 --- a/src/operators/fetch_op.h +++ b/src/operators/fetch_op.h @@ -50,6 +50,8 @@ USE_OP_CPU(fetch); REGISTER_OPERATOR_CPU(fetch, ops::FetchOp); #endif #ifdef PADDLE_MOBILE_MALI_GPU +USE_OP_MALI_GPU(fetch); +REGISTER_OPERATOR_MALI_GPU(fetch, ops::FetchOp); #endif #ifdef PADDLE_MOBILE_FPGA #endif diff --git a/src/operators/fusion_conv_add.cpp b/src/operators/fusion_conv_add.cpp index 2605414c892f89787701334f428621d9d8c2520f..4c01603509b0a1d9da2c2dc31a38719d5117e05c 100644 --- a/src/operators/fusion_conv_add.cpp +++ b/src/operators/fusion_conv_add.cpp @@ -54,6 +54,8 @@ USE_OP_CPU(conv_add); REGISTER_OPERATOR_CPU(conv_add, ops::FusionConvAddOp); #endif #ifdef PADDLE_MOBILE_MALI_GPU +USE_OP_MALI_GPU(conv_add); +REGISTER_OPERATOR_MALI_GPU(conv_add, ops::FusionConvAddOp); #endif #ifdef PADDLE_MOBILE_FPGA #endif diff --git a/src/operators/fusion_conv_add.h b/src/operators/fusion_conv_add.h index 24f1d3f63b3300db9b60a595466a0ced3b9e996b..ba5bb89b8eb4f1e1831f4f5ef83cfdccad68ab9f 100644 --- a/src/operators/fusion_conv_add.h +++ b/src/operators/fusion_conv_add.h @@ -40,7 +40,7 @@ class FusionConvAddMatcher : public framework::FusionOpMatcher { vector> origin_descs = node->OpDescs(node_.Depth()); node->Folder(node_.Depth(), Type(), - {{G_OP_TYPE_ELEMENTWISE_ADD, {"Y", "Y"}}}, removed_nodes); + {{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}}}, removed_nodes); } std::string Type() { return G_OP_TYPE_CONV_ADD; } @@ -68,11 +68,13 @@ class FusionConvAddOp : public framework::OperatorWithKernel< }; #ifdef PADDLE_MOBILE_CPU + #ifndef CONV_ADD_REGISTER static framework::FusionOpRegistrar convadd_registrar( new FusionConvAddMatcher()); #define CONV_ADD_REGISTER #endif + #endif #ifdef PADDLE_MOBILE_MALI_GPU diff --git a/src/operators/fusion_conv_add_relu_op.h b/src/operators/fusion_conv_add_relu_op.h index fd27005c8bef8f8cb91fbf5b6e5a852306c28a9b..bcacb3da3e2ec5371021f3552ffd2c9f53947874 100644 --- a/src/operators/fusion_conv_add_relu_op.h +++ b/src/operators/fusion_conv_add_relu_op.h @@ -36,7 +36,7 @@ class FusionConvAddReluOpMatcher : public framework::FusionOpMatcher { framework::Node *node, std::vector> *removed_nodes) { node->Folder(node_.Depth(), Type(), - {{G_OP_TYPE_ELEMENTWISE_ADD, {"Y", "Y"}}}, removed_nodes); + {{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}}}, removed_nodes); } std::string Type() { return G_OP_TYPE_FUSION_CONV_ADD_RELU; } }; @@ -65,11 +65,11 @@ class FusionConvAddReluOp : public framework::OperatorWithKernel< #ifdef PADDLE_MOBILE_CPU -#ifndef CONV_ADD_RELU_REGISTER -#define CONV_ADD_RELU_REGISTER +//#ifndef CONV_ADD_RELU_REGISTER +//#define CONV_ADD_RELU_REGISTER // static framework::FusionOpRegistrar fusion_conv_add_relu_registrar(new // FusionConvAddReluOpMatcher()); -#endif +//#endif #endif #ifdef PADDLE_MOBILE_MALI_GPU diff --git a/src/operators/fusion_fc_op.h b/src/operators/fusion_fc_op.h index 0ca4d2b27ad46b77ddba55b6b377e741c97bdc9e..ea1f42f0adfb532982f50c2da41fc58f63b54834 100644 --- a/src/operators/fusion_fc_op.h +++ b/src/operators/fusion_fc_op.h @@ -38,7 +38,7 @@ class FusionFcMatcher : public framework::FusionOpMatcher { framework::Node *node, std::vector> *removed_nodes) { node->Folder(node_.Depth(), Type(), - {{G_OP_TYPE_ELEMENTWISE_ADD, {"Y", "Z"}}}, removed_nodes); + {{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Z"}}}}, removed_nodes); } std::string Type() { return G_OP_TYPE_FC; } @@ -66,17 +66,21 @@ class FusionFcOp }; #ifdef PADDLE_MOBILE_CPU + #ifndef CONV_CPU_REGISTER #define CONV_CPU_REGISTER static framework::FusionOpRegistrar fc_registrar(new FusionFcMatcher()); #endif + #endif #ifdef PADDLE_MOBILE_MALI_GPU + #ifndef CONV_CPU_REGISTER #define CONV_CPU_REGISTER static framework::FusionOpRegistrar fc_registrar(new FusionFcMatcher()); #endif + #endif #ifdef PADDLE_MOBILE_FPGA