提交 1356a19a 编写于 作者: E eclipsess

Merge remote-tracking branch 'upstream/develop' into develop

...@@ -18,37 +18,75 @@ SOFTWARE. ...@@ -18,37 +18,75 @@ SOFTWARE.
#pragma once #pragma once
#include "framework/operator.h" #include <map>
#include "operators/kernel/pool_kernel.h"
namespace paddle_mobile { #include "attribute.h"
namespace operators { #include "block_desc.h"
#include "common/type_define.h"
#include "common/types.h"
#include "common/variant.h"
#include "op_info.h"
#include "op_kernel_type.h"
#include "paddle_mobile_object.h"
#include "scope.h"
#include "tensor.h"
#include "variable.h"
using namespace framework; namespace paddle_mobile {
namespace framework {
template <typename DeviceType, typename T> template <typename Dtype> class OperatorBase : PaddleMobileObject {
class ConvOp : public framework::OperatorWithKernel<DeviceType> {
public: public:
ConvOp(const std::string &type, const VariableNameMap &inputs, OperatorBase(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs, const AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope) std::shared_ptr<Scope> scope);
: framework::OperatorWithKernel<DeviceType>( virtual ~OperatorBase() {}
type, inputs, outputs, attrs, scope), virtual void Run() const = 0;
param_(inputs, outputs, attrs, *scope) {}
const VariableNameMap &Inputs() const { return inputs_; }
using framework::OperatorWithKernel<DeviceType>::OperatorWithKernel; const VariableNameMap &Outputs() const { return outputs_; }
void InferShape() const override; const std::string &Type() const { return type_; }
const AttributeMap &Attrs() const { return attrs_; }
void Run() const { void ClearVariables() const {
operators::ConvKernel<DeviceType, T, ConvParam> kernel; if (this->scope_) {
kernel.Compute(param_); this->scope_->EraseVars(this->inputs_.at("Filter"));
this->ClearVariables(); this->scope_->EraseVars(this->inputs_.at("Input"));
}
} }
protected:
std::shared_ptr<Scope> scope_;
std::string type_;
VariableNameMap inputs_;
VariableNameMap outputs_;
AttributeMap attrs_;
private: private:
ConvParam param_; void CheckAllInputOutputSet() const;
};
template <typename Dtype>
class OperatorWithKernel : public OperatorBase<Dtype> {
public:
OperatorWithKernel(const std::string &type,
const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs,
std::shared_ptr<Scope> scope)
: OperatorBase<Dtype>(type, inputs, outputs, attrs, scope) {}
virtual void InferShape() const = 0;
virtual void Run() const = 0;
};
template <typename Dtype, typename P>
class OpKernelBase : PaddleMobileObject {
public:
virtual void Compute(const P &para) const = 0;
virtual ~OpKernelBase() = default;
}; };
} // operators } // namespace framework
} // paddle_mobile } // namespace paddle_mobile
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册