diff --git a/src/framework/operator.h b/src/framework/operator.h index f02771223cb5e14acdb5cbeaaeefb4312470296f..da22bf48abcaf11f5c21d0353f2a47189901f5d9 100644 --- a/src/framework/operator.h +++ b/src/framework/operator.h @@ -18,37 +18,75 @@ SOFTWARE. #pragma once -#include "framework/operator.h" -#include "operators/kernel/pool_kernel.h" +#include -namespace paddle_mobile { - namespace operators { +#include "attribute.h" +#include "block_desc.h" +#include "common/type_define.h" +#include "common/types.h" +#include "common/variant.h" +#include "op_info.h" +#include "op_kernel_type.h" +#include "paddle_mobile_object.h" +#include "scope.h" +#include "tensor.h" +#include "variable.h" - using namespace framework; +namespace paddle_mobile { + namespace framework { - template - class ConvOp : public framework::OperatorWithKernel { + template class OperatorBase : PaddleMobileObject { public: - ConvOp(const std::string &type, const VariableNameMap &inputs, - const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - std::shared_ptr scope) - : framework::OperatorWithKernel( - type, inputs, outputs, attrs, scope), - param_(inputs, outputs, attrs, *scope) {} - - using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape() const override; - - void Run() const { - operators::ConvKernel kernel; - kernel.Compute(param_); - this->ClearVariables(); + OperatorBase(const std::string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, + const AttributeMap &attrs, + std::shared_ptr scope); + virtual ~OperatorBase() {} + virtual void Run() const = 0; + + const VariableNameMap &Inputs() const { return inputs_; } + const VariableNameMap &Outputs() const { return outputs_; } + const std::string &Type() const { return type_; } + const AttributeMap &Attrs() const { return attrs_; } + void ClearVariables() const { + if (this->scope_) { + this->scope_->EraseVars(this->inputs_.at("Filter")); + this->scope_->EraseVars(this->inputs_.at("Input")); + } } + protected: + std::shared_ptr scope_; + std::string type_; + VariableNameMap inputs_; + VariableNameMap outputs_; + AttributeMap attrs_; + private: - ConvParam param_; + void CheckAllInputOutputSet() const; + }; + + template + class OperatorWithKernel : public OperatorBase { + public: + OperatorWithKernel(const std::string &type, + const VariableNameMap &inputs, + const VariableNameMap &outputs, + const AttributeMap &attrs, + std::shared_ptr scope) + : OperatorBase(type, inputs, outputs, attrs, scope) {} + virtual void InferShape() const = 0; + + virtual void Run() const = 0; + }; + + template + class OpKernelBase : PaddleMobileObject { + public: + virtual void Compute(const P ¶) const = 0; + + virtual ~OpKernelBase() = default; }; - } // operators -} // paddle_mobile + } // namespace framework +} // namespace paddle_mobile