/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

// Operator framework header: declares the operator base class, the
// operator-with-kernel adaptor, the kernel base class, the fusion matcher
// base class and three helper macros for declaring concrete operators.
//
// NOTE(review): this copy of the file looks damaged by an extraction step.
// Text that would normally sit inside angle brackets appears to be missing
// (bare `template`, `#include` with no header name, `std::shared_ptr scope`,
// `reinterpret_cast *>`), and `¶` appears where `&para` would be expected.
// Code tokens below are left exactly as found; restore them from the
// upstream file before trying to compile.

#pragma once

// NOTE(review): the next four includes have lost their header names.
#include
#include
#include
#include
#include "common/enforce.h"
#include "common/type_define.h"
#include "common/types.h"
#include "common/variant.h"
#include "framework/attribute.h"
#include "framework/op_info.h"
#include "framework/op_kernel_type.h"
#include "framework/op_registry.h"
#include "framework/program/block_desc.h"
#include "framework/program/program-optimize/node.h"
#include "framework/scope.h"
#include "framework/tensor.h"
#include "framework/variable.h"
// OpenCL helpers are only pulled in for CL builds.
#ifdef PADDLE_MOBILE_CL
#include "framework/cl/cl_helper.h"
#include "framework/cl/cl_scope.h"
#endif

namespace paddle_mobile {
namespace framework {

// Resolves the first variable registered under `key` and returns a mutable
// pointer to its payload.
//
// key      name of the input/output slot to look up in `var_map`.
// var_map  slot name -> list of variable names.
// scope    scope searched for the first variable name in that list.
//
// Returns the result of `GetMutable` on the found variable, or nullptr when
// the list stored under `key` is empty.
//
// NOTE(review): the result of `scope.FindVar(...)` is dereferenced without a
// null check; confirm FindVar cannot return nullptr for names taken from
// `var_map`.
// NOTE(review): `auto var_vec = var_map.at(key)` takes the list by value;
// presumably `at` is the std::map-style accessor that throws on a missing
// key and returns a reference, in which case this is a copy per call.
// Verify against the VariableNameMap definition.
template
static T *GetVarValue(const std::string &key, const VariableNameMap &var_map,
 const Scope &scope) {
  auto var_vec = var_map.at(key);
  if (!var_vec.empty()) {
    auto var = scope.FindVar(var_vec[0]);
    return var->GetMutable();
  } else {
    return nullptr;
  }
}

// Abstract base for every operator. Holds the operator type name, its
// input/output variable-name maps, its attribute map and a shared scope.
// Init, InferShape and RunImpl are pure virtual; Run is virtual and, like the
// constructor, GetOutKeys, GetInputKeys and CheckAllInputOutputSet, is only
// declared here (no definition is visible in this header).
template
class OperatorBase {
 public:
  OperatorBase(const std::string &type, const VariableNameMap &inputs,
 const VariableNameMap &outputs, const AttributeMap &attrs,
 std::shared_ptr scope);
  virtual ~OperatorBase() {}
  virtual void Init() = 0;
  virtual void InferShape() const = 0;
  virtual void Run();
  virtual void RunImpl() = 0;
  std::vector GetOutKeys() const;
  std::vector GetInputKeys() const;
  // Read-only accessors for the state captured at construction.
  const VariableNameMap &Inputs() const { return inputs_; }
  const VariableNameMap &Outputs() const { return outputs_; }
  const std::string &Type() const { return type_; }
  const AttributeMap &Attrs() const { return attrs_; }
  // Forwards `var_names` to the scope's EraseVars; does nothing when no
  // scope is held.
  void ClearVariables(const std::vector &var_names) const {
    if (this->scope_) {
      this->scope_->EraseVars(var_names);
    }
  }

 protected:
  std::shared_ptr scope_;
  std::string type_;
  VariableNameMap inputs_;
  VariableNameMap outputs_;
  AttributeMap attrs_;

 private:
  void CheckAllInputOutputSet() const;
};

// Operator that owns a kernel object (`kernel_`) and a parameter object
// (`param_`) and runs by handing the parameters to the kernel.
// InferShape stays pure virtual, so concrete operators must supply it.
template
class OperatorWithKernel : public OperatorBase {
 public:
  // Forwards all arguments to OperatorBase and builds `param_` from
  // (inputs, outputs, attrs, *scope).
  // NOTE(review): `*scope` is dereferenced unconditionally, so a null
  // `scope` is not tolerated here even though ClearVariables checks for it.
  OperatorWithKernel(const std::string &type, const VariableNameMap &inputs,
 const VariableNameMap &outputs, const AttributeMap &attrs,
 std::shared_ptr scope)
 : OperatorBase(type, inputs, outputs, attrs, scope),
 param_(inputs, outputs, attrs, *scope) {
// CL builds hand the scope's CL scope to the kernel at construction time.
// (`GetCLScpoe` is the spelling used at this call site.)
#ifdef PADDLE_MOBILE_CL
    kernel_.InitCLHelper(scope->GetCLScpoe());
#endif
  }
  // Runs the kernel on the stored parameters.
  virtual void RunImpl() { this->kernel_.Compute(this->param_); }
  virtual void InferShape() const = 0;
  // Initialises the kernel; the result of kernel_.Init is checked with
  // PADDLE_MOBILE_ENFORCE, using the operator type in the message.
  // NOTE(review): `¶m_` looks like a mis-decoded `&param_` - confirm.
  void Init() {
    PADDLE_MOBILE_ENFORCE(kernel_.Init(¶m_), " %s kernel init failed",
 this->type_.c_str());
  }

 protected:
  KernelType kernel_;
  ParamType param_;
};

// Base class for kernels. Compute is pure virtual; Init has a default
// implementation that returns true.
template
class OpKernelBase {
 public:
  OpKernelBase() = default;
// CL builds: store a CLHelper built from the given CL scope.
#ifdef PADDLE_MOBILE_CL
  virtual void InitCLHelper(CLScope *clScope) { cl_helper_ = CLHelper(clScope); }
#endif
// NOTE(review): this guard is spelled `PADDLE_McOBILE_MALI_GPU`, while the
// `acl_op_` member below is guarded by `PADDLE_MOBILE_MALI_GPU`. With only
// the latter defined, `acl_op_` exists but nothing in this class initialises
// or accesses it. Making the two spellings agree as-is would also declare
// `OpKernelBase()` twice (see the `= default` above), so confirm the intent
// before changing either name.
#ifdef PADDLE_McOBILE_MALI_GPU
  OpKernelBase() { acl_op_ = nullptr; }
  void *GetAclOp() const { return acl_op_; }
  // Writes `op` into the acl_op_ of the object `ob` points at (after a
  // reinterpret_cast), not into *this.
  void SetAclOp(void *op, void *ob) const {
    reinterpret_cast *>(ob)->acl_op_ = op;
  }
#endif
  // NOTE(review): `¶` in the parameter below looks like a mis-decoded
  // `&para` - confirm.
  virtual void Compute(const P ¶) = 0;
  virtual bool Init(P *para) { return true; }
  virtual ~OpKernelBase() = default;

 protected:
#ifdef PADDLE_MOBILE_CL
  CLHelper cl_helper_;
#endif

 private:
#ifdef PADDLE_MOBILE_MALI_GPU
  void *acl_op_;
#endif
};

// Base class for fusion matchers. Subclasses must provide Type(); the other
// virtuals have defaults built on the `node_` member.
class FusionOpMatcher {
 public:
  FusionOpMatcher() {}
  virtual std::string Type() = 0;
  // Calls Folder on `node`, passing node_.Depth(), Type(), an empty
  // brace-initialised third argument and `removed_nodes`.
  virtual void FolderNodes(
 Node *node, std::vector> *removed_nodes) {
    node->Folder(node_.Depth(), Type(), {}, removed_nodes);
  }
  virtual Node &BeginNode() { return node_; }
  std::string BeginType() { return node_.Type(); }
  // Default: returns an empty result.
  virtual std::vector> NeedCheck() { return {}; }

 protected:
  Node node_;
  // `type_` and `new_opdesc_` are declared here but not used by any member
  // function visible in this class.
  std::string type_;
  std::shared_ptr new_opdesc_;
};

// Declares class `<OpName>Op`, derived from framework::OperatorWithKernel,
// with a constructor that forwards (type, inputs, outputs, attrs, scope) to
// the base and a declaration of `InferShape() const override`.
// NOTE(review): the template parameter/argument lists inside this macro are
// among the text that appears to be missing from this copy.
#define DECLARE_OPERATOR(OpName, OpParam, OpKernel) \
 template \
 class OpName##Op : public framework::OperatorWithKernel< \
 DeviceType, OpParam, \
 operators::OpKernel> { \
 public: \
 OpName##Op(const std::string &type, const VariableNameMap &inputs, \
 const VariableNameMap &outputs, \
 const framework::AttributeMap &attrs, \
 std::shared_ptr scope) \
 : framework::OperatorWithKernel, \
 operators::OpKernel>( \
 type, inputs, outputs, attrs, scope) {} \
 \
 void InferShape() const override; \
 };

// Declares class `<OpName>Kernel`, derived from framework::OpKernelBase, with
// declarations of Init and Compute.
// NOTE(review): `¶m` below looks like a mis-decoded `&param` - confirm.
#define DECLARE_KERNEL(OpName, OpParam) \
 template \
 class OpName##Kernel \
 : public framework::OpKernelBase> { \
 public: \
 bool Init(OpParam *param); \
 void Compute(const OpParam ¶m); \
 };

// Expands to a constructor for `cls` that forwards
// (type, inputs, outputs, attrs, scope) to `parent_cls`. All types are fully
// qualified from the global namespace.
#define DEFINE_OP_CONSTRUCTOR(cls, parent_cls) \
 cls(const std::string &type, const ::paddle_mobile::VariableNameMap &inputs, \
 const ::paddle_mobile::VariableNameMap &outputs, \
 const ::paddle_mobile::framework::AttributeMap &attrs, \
 std::shared_ptr<::paddle_mobile::framework::Scope> scope) \
 : parent_cls(type, inputs, outputs, attrs, scope) {}

// NOTE(review): in this copy the brace that should close
// `namespace paddle_mobile` sits inside the trailing line comment below, so
// only `namespace framework` is actually closed. Split the line when
// restoring the file.
} // namespace framework } // namespace paddle_mobile