From f83876a015a779ca5b9575e80a67d4a08ac94284 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Fri, 11 Aug 2017 11:31:10 -0700 Subject: [PATCH] Add constructors to OperatorBase and all sub-classes --- paddle/framework/backward_test.cc | 5 +++++ paddle/framework/grad_op_builder_test.cc | 5 +++++ paddle/framework/op_registry_test.cc | 10 ++++++++++ paddle/framework/operator.h | 20 ++++++++++++++++++++ paddle/framework/operator_test.cc | 14 ++++++++++++++ 5 files changed, 54 insertions(+) diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 1677a3ed4..b930b86ed 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -30,6 +30,11 @@ using DeviceContext = platform::DeviceContext; class EmptyOp : public OperatorBase { public: + EmptyOp(const std::string &type, const std::vector &inputs, + const std::vector &outputs, const AttributeMap &attrs, + std::unordered_map *in_out_idxs) + : OperatorBase(type, inputs, outputs, attrs, in_out_idxs) {} + void InferShape(const Scope &scope) const override {} void Run(const Scope &scope, const DeviceContext &dev_ctx) const override {} }; diff --git a/paddle/framework/grad_op_builder_test.cc b/paddle/framework/grad_op_builder_test.cc index f1ebbae52..c3ce69a34 100644 --- a/paddle/framework/grad_op_builder_test.cc +++ b/paddle/framework/grad_op_builder_test.cc @@ -10,6 +10,11 @@ namespace framework { class NOP : public OperatorBase { public: + NOP(const std::string &type, const std::vector &inputs, + const std::vector &outputs, const AttributeMap &attrs, + std::unordered_map *in_out_idxs) + : OperatorBase(type, inputs, outputs, attrs, in_out_idxs) {} + void InferShape(const Scope &scope) const override {} void Run(const Scope &scope, const platform::DeviceContext &dev_ctx) const override {} diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index 9894928a7..de3435ad3 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -7,6 +7,11 @@ namespace paddle { namespace framework { class CosineOp : public OperatorBase { public: + CosineOp(const std::string& type, const std::vector& inputs, + const std::vector& outputs, const AttributeMap& attrs, + std::unordered_map* in_out_idxs) + : OperatorBase(type, inputs, outputs, attrs, in_out_idxs) {} + void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const override {} void InferShape(const Scope& scope) const override {} @@ -27,6 +32,11 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class MyTestOp : public OperatorBase { public: + MyTestOp(const std::string& type, const std::vector& inputs, + const std::vector& outputs, const AttributeMap& attrs, + std::unordered_map* in_out_idxs) + : OperatorBase(type, inputs, outputs, attrs, in_out_idxs) {} + void InferShape(const Scope& scope) const override {} void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const override {} diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index f5d167a16..8b7f74367 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -63,6 +63,16 @@ class ExecutionContext; */ class OperatorBase { public: + OperatorBase(const std::string& type, const std::vector& inputs, + const std::vector& outputs, + const AttributeMap& attrs, + std::unordered_map* in_out_idxs) + : type_(type), + inputs_(input), + outputs_(output), + attrs_(attrs), + in_out_idxs_(in_out_idxs) {} + virtual ~OperatorBase() {} template @@ -109,6 +119,9 @@ class OperatorBase { const std::vector Inputs() const { return inputs_; } const std::vector Outputs() const { return outputs_; } const AttributeMap& Attrs() const { return attrs_; } + const std::unordered_map* InOutIdx() const { + return in_out_idxs_.get(); + } public: std::string type_; @@ -286,6 +299,13 @@ class OpKernel { class OperatorWithKernel : public OperatorBase { public: + OperatorWithKernel(const std::string& type, + const std::vector& inputs, + const std::vector& outputs, + const AttributeMap& attrs, + std::unordered_map* in_out_idxs) + : OperatorBase(type, inputs, outputs, attrs, in_out_idxs) {} + struct OpKernelKey { platform::Place place_; diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 387aada74..a538abe7f 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -23,6 +23,13 @@ static int op_run_num = 0; class OpWithoutKernelTest : public OperatorBase { public: + OpWithoutKernelTest(const std::string& type, + const std::vector& inputs, + const std::vector& outputs, + const AttributeMap& attrs, + std::unordered_map* in_out_idxs) + : OperatorBase(type, inputs, outputs, attrs, in_out_idxs) {} + void Init() override { x = 1; } void InferShape(const Scope& scope) const override {} void Run(const Scope& scope, @@ -116,6 +123,13 @@ class CPUKernelTest : public OpKernel { // multiple inputs test class OperatorMultiInputsTest : public OperatorBase { public: + OperatorMultiInputsTest(const std::string& type, + const std::vector& inputs, + const std::vector& outputs, + const AttributeMap& attrs, + std::unordered_map* in_out_idxs) + : OperatorBase(type, inputs, outputs, attrs, in_out_idxs) {} + void Init() override { x = 1; } void InferShape(const Scope& scope) const override {} void Run(const Scope& scope, -- GitLab