提交 f83876a0 编写于 作者: Y Yi Wang

Add constructors to OperatorBase and all sub-classes

上级 d08b9538
...@@ -30,6 +30,11 @@ using DeviceContext = platform::DeviceContext; ...@@ -30,6 +30,11 @@ using DeviceContext = platform::DeviceContext;
class EmptyOp : public OperatorBase { class EmptyOp : public OperatorBase {
public: public:
EmptyOp(const std::string &type, const std::vector<std::string> &inputs,
const std::vector<std::string> &outputs, const AttributeMap &attrs,
std::unordered_map<std::string, int> *in_out_idxs)
: OperatorBase(type, inputs, outputs, attrs, in_out_idxs) {}
void InferShape(const Scope &scope) const override {} void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope, const DeviceContext &dev_ctx) const override {} void Run(const Scope &scope, const DeviceContext &dev_ctx) const override {}
}; };
......
...@@ -10,6 +10,11 @@ namespace framework { ...@@ -10,6 +10,11 @@ namespace framework {
class NOP : public OperatorBase { class NOP : public OperatorBase {
public: public:
NOP(const std::string &type, const std::vector<std::string> &inputs,
const std::vector<std::string> &outputs, const AttributeMap &attrs,
std::unordered_map<std::string, int> *in_out_idxs)
: OperatorBase(type, inputs, outputs, attrs, in_out_idxs) {}
void InferShape(const Scope &scope) const override {} void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope, void Run(const Scope &scope,
const platform::DeviceContext &dev_ctx) const override {} const platform::DeviceContext &dev_ctx) const override {}
......
...@@ -7,6 +7,11 @@ namespace paddle { ...@@ -7,6 +7,11 @@ namespace paddle {
namespace framework { namespace framework {
class CosineOp : public OperatorBase { class CosineOp : public OperatorBase {
public: public:
CosineOp(const std::string& type, const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs, const AttributeMap& attrs,
std::unordered_map<std::string, int>* in_out_idxs)
: OperatorBase(type, inputs, outputs, attrs, in_out_idxs) {}
void Run(const Scope& scope, void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {} const platform::DeviceContext& dev_ctx) const override {}
void InferShape(const Scope& scope) const override {} void InferShape(const Scope& scope) const override {}
...@@ -27,6 +32,11 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { ...@@ -27,6 +32,11 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
class MyTestOp : public OperatorBase { class MyTestOp : public OperatorBase {
public: public:
MyTestOp(const std::string& type, const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs, const AttributeMap& attrs,
std::unordered_map<std::string, int>* in_out_idxs)
: OperatorBase(type, inputs, outputs, attrs, in_out_idxs) {}
void InferShape(const Scope& scope) const override {} void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope, void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {} const platform::DeviceContext& dev_ctx) const override {}
......
...@@ -63,6 +63,16 @@ class ExecutionContext; ...@@ -63,6 +63,16 @@ class ExecutionContext;
*/ */
class OperatorBase { class OperatorBase {
public: public:
OperatorBase(const std::string& type, const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
const AttributeMap& attrs,
std::unordered_map<std::string, int>* in_out_idxs)
: type_(type),
inputs_(input),
outputs_(output),
attrs_(attrs),
in_out_idxs_(in_out_idxs) {}
virtual ~OperatorBase() {} virtual ~OperatorBase() {}
template <typename T> template <typename T>
...@@ -109,6 +119,9 @@ class OperatorBase { ...@@ -109,6 +119,9 @@ class OperatorBase {
const std::vector<std::string> Inputs() const { return inputs_; } const std::vector<std::string> Inputs() const { return inputs_; }
const std::vector<std::string> Outputs() const { return outputs_; } const std::vector<std::string> Outputs() const { return outputs_; }
const AttributeMap& Attrs() const { return attrs_; } const AttributeMap& Attrs() const { return attrs_; }
const std::unordered_map<std::string, int>* InOutIdx() const {
return in_out_idxs_.get();
}
public: public:
std::string type_; std::string type_;
...@@ -286,6 +299,13 @@ class OpKernel { ...@@ -286,6 +299,13 @@ class OpKernel {
class OperatorWithKernel : public OperatorBase { class OperatorWithKernel : public OperatorBase {
public: public:
OperatorWithKernel(const std::string& type,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
const AttributeMap& attrs,
std::unordered_map<std::string, int>* in_out_idxs)
: OperatorBase(type, inputs, outputs, attrs, in_out_idxs) {}
struct OpKernelKey { struct OpKernelKey {
platform::Place place_; platform::Place place_;
......
...@@ -23,6 +23,13 @@ static int op_run_num = 0; ...@@ -23,6 +23,13 @@ static int op_run_num = 0;
class OpWithoutKernelTest : public OperatorBase { class OpWithoutKernelTest : public OperatorBase {
public: public:
OpWithoutKernelTest(const std::string& type,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
const AttributeMap& attrs,
std::unordered_map<std::string, int>* in_out_idxs)
: OperatorBase(type, inputs, outputs, attrs, in_out_idxs) {}
void Init() override { x = 1; } void Init() override { x = 1; }
void InferShape(const Scope& scope) const override {} void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope, void Run(const Scope& scope,
...@@ -116,6 +123,13 @@ class CPUKernelTest : public OpKernel { ...@@ -116,6 +123,13 @@ class CPUKernelTest : public OpKernel {
// multiple inputs test // multiple inputs test
class OperatorMultiInputsTest : public OperatorBase { class OperatorMultiInputsTest : public OperatorBase {
public: public:
OperatorMultiInputsTest(const std::string& type,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
const AttributeMap& attrs,
std::unordered_map<std::string, int>* in_out_idxs)
: OperatorBase(type, inputs, outputs, attrs, in_out_idxs) {}
void Init() override { x = 1; } void Init() override { x = 1; }
void InferShape(const Scope& scope) const override {} void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope, void Run(const Scope& scope,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册