diff --git a/src/framework/operator.cpp b/src/framework/operator.cpp index 7e8b9ce48e30555b5a66e6bfecffb618c83e8e41..d7f840ea3e830b63c94c63752115a71463a34c1c 100644 --- a/src/framework/operator.cpp +++ b/src/framework/operator.cpp @@ -32,9 +32,6 @@ namespace paddle_mobile { scope_(scope) { CheckAllInputOutputSet(); } - - template void OperatorBase::Run() { RunImpl(); } - template void OperatorBase::CheckAllInputOutputSet() const {} diff --git a/src/framework/operator.h b/src/framework/operator.h index 782dfd0b79b8c12ffd39fe4a990fd48637689e38..f02771223cb5e14acdb5cbeaaeefb4312470296f 100644 --- a/src/framework/operator.h +++ b/src/framework/operator.h @@ -18,72 +18,37 @@ SOFTWARE. #pragma once -#include - -#include "attribute.h" -#include "block_desc.h" -#include "common/type_define.h" -#include "common/types.h" -#include "common/variant.h" -#include "op_info.h" -#include "op_kernel_type.h" -#include "paddle_mobile_object.h" -#include "scope.h" -#include "tensor.h" -#include "variable.h" +#include "framework/operator.h" +#include "operators/kernel/pool_kernel.h" namespace paddle_mobile { - namespace framework { - - template class OperatorBase : PaddleMobileObject { - public: - OperatorBase(const std::string &type, const VariableNameMap &inputs, - const VariableNameMap &outputs, - const AttributeMap &attrs, - std::shared_ptr scope); - virtual ~OperatorBase() {} - virtual void Run(); - const VariableNameMap &Inputs() const { return inputs_; } - const VariableNameMap &Outputs() const { return outputs_; } - const std::string &Type() const { return type_; } - const AttributeMap &Attrs() const { return attrs_; } + namespace operators { - protected: - std::shared_ptr scope_; - std::string type_; - VariableNameMap inputs_; - VariableNameMap outputs_; - AttributeMap attrs_; - - private: - void CheckAllInputOutputSet() const; - virtual void RunImpl() const = 0; - }; + using namespace framework; - template - class OperatorWithKernel : public OperatorBase { + template + class ConvOp : public framework::OperatorWithKernel { public: - OperatorWithKernel(const std::string &type, - const VariableNameMap &inputs, - const VariableNameMap &outputs, - const AttributeMap &attrs, - std::shared_ptr scope) - : OperatorBase(type, inputs, outputs, attrs, scope) {} - virtual void InferShape() const = 0; - - protected: - virtual void RunImpl() const = 0; + ConvOp(const std::string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, + const framework::AttributeMap &attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel( + type, inputs, outputs, attrs, scope), + param_(inputs, outputs, attrs, *scope) {} + + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape() const override; + + void Run() const { + operators::ConvKernel kernel; + kernel.Compute(param_); + this->ClearVariables(); + } private: + ConvParam param_; }; - template - class OpKernelBase : PaddleMobileObject { - public: - virtual void Compute(const P ¶) const = 0; - - virtual ~OpKernelBase() = default; - }; - - } // namespace framework -} // namespace paddle_mobile + } // operators +} // paddle_mobile diff --git a/src/framework/variable.h b/src/framework/variable.h index 55b1bb0449fb07728a23de325967822e94100275..b80b3f5d27bf81a981023341f6c5ff74042d14c0 100644 --- a/src/framework/variable.h +++ b/src/framework/variable.h @@ -29,9 +29,6 @@ namespace paddle_mobile { namespace framework { class Variable : public PaddleMobileObject { public: - Variable() {} - ~Variable() {} - template const T *Get() const { return static_cast(holder_->Ptr()); } diff --git a/src/operators/conv_op.h b/src/operators/conv_op.h index 07ac49b555af9db2906d9b537729dea4e90bcad9..4b2db961cd2872f3f8fc25f1763086a041b99427 100644 --- a/src/operators/conv_op.h +++ b/src/operators/conv_op.h @@ -40,12 +40,13 @@ namespace paddle_mobile { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape() const override; - protected: - void RunImpl() const { + void Run() const { operators::ConvKernel kernel; kernel.Compute(param_); + this->ClearVariables(); } + private: ConvParam param_; };