diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 1677a3ed4c85ef293f0aadc64a4caa809cbd6ced..b930b86ed6b714b56ffa98dbbadbf1b2e806ff2e 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -30,6 +30,11 @@ using DeviceContext = platform::DeviceContext; class EmptyOp : public OperatorBase { public: + EmptyOp(const std::string &type, const std::vector &inputs, + const std::vector &outputs, const AttributeMap &attrs, + std::unordered_map *in_out_idxs) + : OperatorBase(type, inputs, outputs, attrs, in_out_idxs) {} + void InferShape(const Scope &scope) const override {} void Run(const Scope &scope, const DeviceContext &dev_ctx) const override {} }; diff --git a/paddle/framework/grad_op_builder_test.cc b/paddle/framework/grad_op_builder_test.cc index f1ebbae52f13d9c0fc9408aec8c4160575ad59c0..c3ce69a344460ac94aa942d9d2bb5c7a3a1eaf05 100644 --- a/paddle/framework/grad_op_builder_test.cc +++ b/paddle/framework/grad_op_builder_test.cc @@ -10,6 +10,11 @@ namespace framework { class NOP : public OperatorBase { public: + NOP(const std::string &type, const std::vector &inputs, + const std::vector &outputs, const AttributeMap &attrs, + std::unordered_map *in_out_idxs) + : OperatorBase(type, inputs, outputs, attrs, in_out_idxs) {} + void InferShape(const Scope &scope) const override {} void Run(const Scope &scope, const platform::DeviceContext &dev_ctx) const override {} diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index 9894928a7aa19bc6c7ad8b230562fb9a681cfebd..de3435ad35eab9d45335f269709222311c80781e 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -7,6 +7,11 @@ namespace paddle { namespace framework { class CosineOp : public OperatorBase { public: + CosineOp(const std::string& type, const std::vector& inputs, + const std::vector& outputs, const AttributeMap& attrs, + std::unordered_map* in_out_idxs) + : OperatorBase(type, inputs, outputs, attrs, in_out_idxs) {} + void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const override {} void InferShape(const Scope& scope) const override {} @@ -27,6 +32,11 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class MyTestOp : public OperatorBase { public: + MyTestOp(const std::string& type, const std::vector& inputs, + const std::vector& outputs, const AttributeMap& attrs, + std::unordered_map* in_out_idxs) + : OperatorBase(type, inputs, outputs, attrs, in_out_idxs) {} + void InferShape(const Scope& scope) const override {} void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const override {} diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index f5d167a16ec577f6989593122715ac5681d11eda..8b7f74367147a070e4b9f30f5934e31c551abe8d 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -63,6 +63,16 @@ class ExecutionContext; */ class OperatorBase { public: + OperatorBase(const std::string& type, const std::vector& inputs, + const std::vector& outputs, + const AttributeMap& attrs, + std::unordered_map* in_out_idxs) + : type_(type), + inputs_(input), + outputs_(output), + attrs_(attrs), + in_out_idxs_(in_out_idxs) {} + virtual ~OperatorBase() {} template @@ -109,6 +119,9 @@ class OperatorBase { const std::vector Inputs() const { return inputs_; } const std::vector Outputs() const { return outputs_; } const AttributeMap& Attrs() const { return attrs_; } + const std::unordered_map* InOutIdx() const { + return in_out_idxs_.get(); + } public: std::string type_; @@ -286,6 +299,13 @@ class OpKernel { class OperatorWithKernel : public OperatorBase { public: + OperatorWithKernel(const std::string& type, + const std::vector& inputs, + const std::vector& outputs, + const AttributeMap& attrs, + std::unordered_map* in_out_idxs) + : OperatorBase(type, inputs, outputs, attrs, in_out_idxs) {} + struct OpKernelKey { platform::Place place_; diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 387aada749ba62246b44dedc050547c05955caa9..a538abe7fecb9f94d70a92faf838f764787a15ec 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -23,6 +23,13 @@ static int op_run_num = 0; class OpWithoutKernelTest : public OperatorBase { public: + OpWithoutKernelTest(const std::string& type, + const std::vector& inputs, + const std::vector& outputs, + const AttributeMap& attrs, + std::unordered_map* in_out_idxs) + : OperatorBase(type, inputs, outputs, attrs, in_out_idxs) {} + void Init() override { x = 1; } void InferShape(const Scope& scope) const override {} void Run(const Scope& scope, @@ -116,6 +123,13 @@ class CPUKernelTest : public OpKernel { // multiple inputs test class OperatorMultiInputsTest : public OperatorBase { public: + OperatorMultiInputsTest(const std::string& type, + const std::vector& inputs, + const std::vector& outputs, + const AttributeMap& attrs, + std::unordered_map* in_out_idxs) + : OperatorBase(type, inputs, outputs, attrs, in_out_idxs) {} + void Init() override { x = 1; } void InferShape(const Scope& scope) const override {} void Run(const Scope& scope,