diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 1677a3ed4c85ef293f0aadc64a4caa809cbd6ced..da3b9c8bed7cd123f2f8ef982a5f0e23abcc0ec7 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -30,6 +30,8 @@ using DeviceContext = platform::DeviceContext; class EmptyOp : public OperatorBase { public: + DEFINE_OPERATOR_CTOR(EmptyOp, OperatorBase) + void InferShape(const Scope &scope) const override {} void Run(const Scope &scope, const DeviceContext &dev_ctx) const override {} }; diff --git a/paddle/framework/grad_op_builder_test.cc b/paddle/framework/grad_op_builder_test.cc index f1ebbae52f13d9c0fc9408aec8c4160575ad59c0..19e552b7458c966d473bdee99515a2beee1f6089 100644 --- a/paddle/framework/grad_op_builder_test.cc +++ b/paddle/framework/grad_op_builder_test.cc @@ -10,6 +10,8 @@ namespace framework { class NOP : public OperatorBase { public: + DEFINE_OPERATOR_CTOR(NOP, OperatorBase) + void InferShape(const Scope &scope) const override {} void Run(const Scope &scope, const platform::DeviceContext &dev_ctx) const override {} diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 84bf325fedcfa919e509c0d19218a84fa46fab37..cb9164eec1788c2c19176115e8687bed49d8c0b6 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -69,18 +69,18 @@ class OpProtoAndCheckerMaker { VariableBuilder AddInput(const std::string& name, const std::string& comment) { - auto input = proto_->mutable_inputs()->Add(); - *input->mutable_name() = name; - *input->mutable_comment() = comment; + VarProto* input = proto_->add_inputs(); + input->set_name(name); + input->set_comment(comment); return VariableBuilder{input, [=] { this->SetHasMultipleInput(); }, nullptr}; } VariableBuilder AddOutput(const std::string& name, const std::string& comment) { - auto output = proto_->mutable_outputs()->Add(); - *output->mutable_name() = name; - *output->mutable_comment() = comment; + VarProto* output = proto_->add_outputs(); + output->set_name(name); + output->set_comment(comment); return VariableBuilder{output, [=] { this->SetHasMultipleOutput(); }, [=] { this->SetHasTemporaryOutput(); }}; } @@ -89,17 +89,15 @@ class OpProtoAndCheckerMaker { TypedAttrChecker& AddAttr(const std::string& name, const std::string& comment, bool generated = false) { - auto attr = proto_->mutable_attrs()->Add(); - *attr->mutable_name() = name; - *attr->mutable_comment() = comment; + AttrProto* attr = proto_->add_attrs(); + attr->set_name(name); + attr->set_comment(comment); attr->set_generated(generated); attr->set_type(AttrTypeID()); return op_checker_->AddAttrChecker(name); } - void AddComment(const std::string& comment) { - *(proto_->mutable_comment()) = comment; - } + void AddComment(const std::string& comment) { proto_->set_comment(comment); } private: void SetHasMultiple(const std::string& in_out, bool* flag) { @@ -187,7 +185,7 @@ class OpRegistry { OpProto& op_proto = protos()[op_type]; auto maker = ProtoMakerType(&op_proto, &op_checker); maker.Validate(); - *op_proto.mutable_type() = op_type; + op_proto.set_type(op_type); PADDLE_ENFORCE( op_proto.IsInitialized(), "Fail to initialize %s's OpProto, because %s is not initialized", diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index 9894928a7aa19bc6c7ad8b230562fb9a681cfebd..e64126c7093a8eebc219afa4979d941ddc1afc97 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -7,6 +7,8 @@ namespace paddle { namespace framework { class CosineOp : public OperatorBase { public: + DEFINE_OPERATOR_CTOR(CosineOp, OperatorBase) + void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const override {} void InferShape(const Scope& scope) const override {} @@ -27,6 +29,8 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class MyTestOp : public OperatorBase { public: + DEFINE_OPERATOR_CTOR(MyTestOp, OperatorBase) + void InferShape(const Scope& scope) const override {} void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const override {} diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index f5d167a16ec577f6989593122715ac5681d11eda..68e7fedcd6102435a3c30326aa91043b8abecb9e 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -63,6 +63,17 @@ class ExecutionContext; */ class OperatorBase { public: + OperatorBase() {} // TODO(yi): This constructor is to be removed. + OperatorBase(const std::string& type, const std::vector& inputs, + const std::vector& outputs, + const AttributeMap& attrs, + std::unordered_map* in_out_idxs) + : type_(type), + inputs_(inputs), + outputs_(outputs), + attrs_(attrs), + in_out_idxs_(in_out_idxs) {} + virtual ~OperatorBase() {} template @@ -109,6 +120,9 @@ class OperatorBase { const std::vector Inputs() const { return inputs_; } const std::vector Outputs() const { return outputs_; } const AttributeMap& Attrs() const { return attrs_; } + const std::unordered_map* InOutIdx() const { + return in_out_idxs_.get(); + } public: std::string type_; @@ -286,6 +300,14 @@ class OpKernel { class OperatorWithKernel : public OperatorBase { public: + OperatorWithKernel() {} // TODO(yi): This constructor is to be removed. + OperatorWithKernel(const std::string& type, + const std::vector& inputs, + const std::vector& outputs, + const AttributeMap& attrs, + std::unordered_map* in_out_idxs) + : OperatorBase(type, inputs, outputs, attrs, in_out_idxs) {} + struct OpKernelKey { platform::Place place_; @@ -335,5 +357,15 @@ class OperatorWithKernel : public OperatorBase { virtual void InferShape(const InferShapeContext& ctx) const = 0; }; +#define DEFINE_OPERATOR_CTOR(Class, ParentClass) \ + public: \ + Class() { /* TODO(yi): This constructor is to be removed. */ \ + } \ + Class(const std::string& type, const std::vector& inputs, \ + const std::vector& outputs, \ + const ::paddle::framework::AttributeMap& attrs, \ + std::unordered_map* in_out_idxs) \ + : ParentClass(type, inputs, outputs, attrs, in_out_idxs) {} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 387aada749ba62246b44dedc050547c05955caa9..7dbd5b14ab6ec89ae9940a3d12ec9d2b169153ad 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -23,6 +23,8 @@ static int op_run_num = 0; class OpWithoutKernelTest : public OperatorBase { public: + DEFINE_OPERATOR_CTOR(OpWithoutKernelTest, OperatorBase) + void Init() override { x = 1; } void InferShape(const Scope& scope) const override {} void Run(const Scope& scope, @@ -97,6 +99,8 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { static int cpu_kernel_run_num = 0; class OpWithKernelTest : public OperatorWithKernel { + public: + DEFINE_OPERATOR_CTOR(OpWithKernelTest, OperatorWithKernel) protected: void InferShape(const framework::InferShapeContext& ctx) const override {} }; @@ -116,6 +120,8 @@ class CPUKernelTest : public OpKernel { // multiple inputs test class OperatorMultiInputsTest : public OperatorBase { public: + DEFINE_OPERATOR_CTOR(OperatorMultiInputsTest, OperatorBase) + void Init() override { x = 1; } void InferShape(const Scope& scope) const override {} void Run(const Scope& scope, diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc index 086245ef62d759ab20a3684ddbc015f6c6258639..b886ded9bbd97dc1942c87d7603521e8d72e3f6c 100644 --- a/paddle/operators/add_op.cc +++ b/paddle/operators/add_op.cc @@ -18,6 +18,7 @@ namespace paddle { namespace operators { class AddOp : public framework::OperatorWithKernel { + DEFINE_OPERATOR_CTOR(AddOp, framework::OperatorWithKernel) protected: void InferShape(const framework::InferShapeContext &ctx) const override { PADDLE_ENFORCE_EQ(ctx.InputSize(), 2); @@ -47,6 +48,7 @@ The equation is: Out = X + Y }; class AddOpGrad : public framework::OperatorWithKernel { + DEFINE_OPERATOR_CTOR(AddOpGrad, framework::OperatorWithKernel) protected: void InferShape(const framework::InferShapeContext &ctx) const override {} }; diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc index c813d54e17fa48aa4447ef76b918b7355be52b09..09aa589d3caf7ed7b790850b515d49afdd3e1467 100644 --- a/paddle/operators/cross_entropy_op.cc +++ b/paddle/operators/cross_entropy_op.cc @@ -18,6 +18,7 @@ namespace paddle { namespace operators { class OnehotCrossEntropyOp : public framework::OperatorWithKernel { + DEFINE_OPERATOR_CTOR(OnehotCrossEntropyOp, framework::OperatorWithKernel) protected: void InferShape(const framework::InferShapeContext &ctx) const override { PADDLE_ENFORCE_EQ(ctx.InputSize(), 2, @@ -38,6 +39,8 @@ class OnehotCrossEntropyOp : public framework::OperatorWithKernel { }; class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel { + DEFINE_OPERATOR_CTOR(OnehotCrossEntropyGradientOp, + framework::OperatorWithKernel) protected: void InferShape(const framework::InferShapeContext &ctx) const override { auto X_grad = ctx.Output(framework::GradVarName("X")); diff --git a/paddle/operators/fill_zeros_like_op.cc b/paddle/operators/fill_zeros_like_op.cc index 3759a886780e555ccdc6286c4b200a5d14214691..eda23a0ccfacd3a620412876e18f4ec47652bf9d 100644 --- a/paddle/operators/fill_zeros_like_op.cc +++ b/paddle/operators/fill_zeros_like_op.cc @@ -18,6 +18,7 @@ namespace paddle { namespace operators { class FillZerosLikeOp : public framework::OperatorWithKernel { + DEFINE_OPERATOR_CTOR(FillZerosLikeOp, framework::OperatorWithKernel) protected: void InferShape(const framework::InferShapeContext &ctx) const override { PADDLE_ENFORCE_EQ(ctx.InputSize(), 1UL, diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc index ef417ae2f06e8a9f10aed80674015e2ee448f4a3..893cf56e5cf0d99d3f3bfffe98734a868f9b7595 100644 --- a/paddle/operators/gaussian_random_op.cc +++ b/paddle/operators/gaussian_random_op.cc @@ -43,6 +43,7 @@ class GaussianRandomKernel : public framework::OpKernel { }; class GaussianRandomOp : public framework::OperatorWithKernel { + DEFINE_OPERATOR_CTOR(GaussianRandomOp, framework::OperatorWithKernel) protected: void InferShape(const framework::InferShapeContext& context) const override { auto* tensor = context.Output(0); diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc index 2ea049cb3605f4dedabb992ebc0e8aa276ad5e9a..f6abba7ab45728f74dcea1363035a729b2cd1d03 100644 --- a/paddle/operators/mean_op.cc +++ b/paddle/operators/mean_op.cc @@ -18,6 +18,7 @@ namespace paddle { namespace operators { class MeanOp : public framework::OperatorWithKernel { + DEFINE_OPERATOR_CTOR(MeanOp, framework::OperatorWithKernel) protected: void InferShape(const framework::InferShapeContext &ctx) const override { PADDLE_ENFORCE_EQ(ctx.InputSize(), 1, "Input size of AddOp must be one"); @@ -39,6 +40,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker { }; class MeanGradOp : public framework::OperatorWithKernel { + DEFINE_OPERATOR_CTOR(MeanGradOp, framework::OperatorWithKernel) protected: void InferShape(const framework::InferShapeContext &ctx) const override { ctx.Output(framework::GradVarName("X")) diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index db81fd555d1c7bea7c0c3bbd70266b4952ed3724..6115a3f3332dba419b56e74a737627483448a715 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -18,6 +18,7 @@ namespace paddle { namespace operators { class MulOp : public framework::OperatorWithKernel { + DEFINE_OPERATOR_CTOR(MulOp, framework::OperatorWithKernel) protected: void InferShape(const framework::InferShapeContext &ctx) const override { PADDLE_ENFORCE(ctx.InputSize() == 2, "The mul op must take two inputs"); @@ -53,6 +54,7 @@ The equation is: Out = X * Y }; class MulOpGrad : public framework::OperatorWithKernel { + DEFINE_OPERATOR_CTOR(MulOpGrad, framework::OperatorWithKernel) protected: void InferShape(const framework::InferShapeContext &ctx) const override {} std::string DebugString() const override { diff --git a/paddle/operators/net_op.h b/paddle/operators/net_op.h index 792b336675fc97659d9a23358cf3d48ede56e54e..24c9e61c66933c6be5bf44b3537e00b70a33922f 100644 --- a/paddle/operators/net_op.h +++ b/paddle/operators/net_op.h @@ -35,6 +35,8 @@ namespace operators { */ class NetOp : public framework::OperatorBase { public: + DEFINE_OPERATOR_CTOR(NetOp, framework::OperatorBase) + /** * Infer all the operators' input and output variables' shapes, will be called * before every mini-batch diff --git a/paddle/operators/net_op_test.cc b/paddle/operators/net_op_test.cc index 76bf79f9b51fd759da2d02cd90fa458a32be4178..0d5c3de798d0b580860d24ea9a61a6a4ede5d0ab 100644 --- a/paddle/operators/net_op_test.cc +++ b/paddle/operators/net_op_test.cc @@ -12,6 +12,8 @@ static int run_cnt = 0; class TestOp : public framework::OperatorBase { public: + DEFINE_OPERATOR_CTOR(TestOp, framework::OperatorBase) + void InferShape(const Scope& scope) const override { ++infer_shape_cnt; } void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const override { @@ -21,6 +23,8 @@ class TestOp : public framework::OperatorBase { class EmptyOp : public framework::OperatorBase { public: + DEFINE_OPERATOR_CTOR(EmptyOp, framework::OperatorBase) + void InferShape(const Scope& scope) const override {} void Run(const Scope& scope, const DeviceContext& dev_ctx) const override {} }; diff --git a/paddle/operators/recurrent_op.h b/paddle/operators/recurrent_op.h index d1e60fed9cef3c6dccba3ad498fc3658a177b3f7..fdd9d005378e63b8d44803fb2b4be83d134c6a5b 100644 --- a/paddle/operators/recurrent_op.h +++ b/paddle/operators/recurrent_op.h @@ -100,6 +100,7 @@ class RecurrentGradientAlgorithm { }; class RecurrentOp final : public framework::OperatorBase { + DEFINE_OPERATOR_CTOR(RecurrentOp, framework::OperatorBase) public: void Init() override; diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc index 55ed1c2f4c316656de94b24dd95b053a89d7e74e..402f6340a04d9b423bb16431a99a2f2866d203bc 100644 --- a/paddle/operators/rowwise_add_op.cc +++ b/paddle/operators/rowwise_add_op.cc @@ -18,6 +18,7 @@ namespace paddle { namespace operators { class RowWiseAddOp : public framework::OperatorWithKernel { + DEFINE_OPERATOR_CTOR(RowWiseAddOp, framework::OperatorWithKernel) protected: void InferShape(const framework::InferShapeContext &ctx) const override { PADDLE_ENFORCE(ctx.InputSize() == 2UL, diff --git a/paddle/operators/sgd_op.cc b/paddle/operators/sgd_op.cc index f9a28ff8a6a06c5c239c4e6ec21eacb410cc162f..5b8093f0f77e0982a7ad25b42b299a6461712630 100644 --- a/paddle/operators/sgd_op.cc +++ b/paddle/operators/sgd_op.cc @@ -18,6 +18,7 @@ namespace paddle { namespace operators { class SGDOp : public framework::OperatorWithKernel { + DEFINE_OPERATOR_CTOR(SGDOp, framework::OperatorWithKernel) protected: void InferShape(const framework::InferShapeContext &ctx) const override { PADDLE_ENFORCE_EQ(ctx.InputSize(), 2, "Input size of SGDOp must be two"); diff --git a/paddle/operators/sigmoid_op.cc b/paddle/operators/sigmoid_op.cc index bc5e0bbb183f9bdf0a3fa8a5a02499282fbd6b98..a02e2dc39e8f0d3e31c22a5cafeff111d08aa905 100644 --- a/paddle/operators/sigmoid_op.cc +++ b/paddle/operators/sigmoid_op.cc @@ -18,6 +18,7 @@ namespace paddle { namespace operators { class SigmoidOp : public framework::OperatorWithKernel { + DEFINE_OPERATOR_CTOR(SigmoidOp, framework::OperatorWithKernel) protected: void InferShape(const framework::InferShapeContext &ctx) const override { PADDLE_ENFORCE(ctx.InputSize() == 1, "Sigmoid Op only have one input"); @@ -38,6 +39,7 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker { }; class SigmoidOpGrad : public framework::OperatorWithKernel { + DEFINE_OPERATOR_CTOR(SigmoidOpGrad, framework::OperatorWithKernel) protected: void InferShape(const framework::InferShapeContext &ctx) const override { ctx.Output(0)->Resize(ctx.Input(0)->dims()); diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc index 3dd4e86918a86f408e7867d15b4fdc8f9cbbb5ce..9b6a679642303a2cb34954ce16b4a5811acf0ec2 100644 --- a/paddle/operators/softmax_op.cc +++ b/paddle/operators/softmax_op.cc @@ -18,6 +18,7 @@ namespace paddle { namespace operators { class SoftmaxOp : public framework::OperatorWithKernel { + DEFINE_OPERATOR_CTOR(SoftmaxOp, framework::OperatorWithKernel) protected: void InferShape(const framework::InferShapeContext &ctx) const override { PADDLE_ENFORCE_EQ(ctx.InputSize(), 1UL, @@ -42,6 +43,7 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { }; class SoftmaxOpGrad : public framework::OperatorWithKernel { + DEFINE_OPERATOR_CTOR(SoftmaxOpGrad, framework::OperatorWithKernel) protected: void InferShape(const framework::InferShapeContext &ctx) const override { PADDLE_ENFORCE_EQ(ctx.InputSize(), 3UL, diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc index 405b84b76d2e24db25d2ff16e99495f2f132ef09..ea81ec053f8b9029114f7c98d292a778dc50c3e4 100644 --- a/paddle/operators/uniform_random_op.cc +++ b/paddle/operators/uniform_random_op.cc @@ -46,6 +46,7 @@ class CPUUniformRandomKernel : public framework::OpKernel { }; class UniformRandomOp : public framework::OperatorWithKernel { + DEFINE_OPERATOR_CTOR(UniformRandomOp, framework::OperatorWithKernel) protected: void InferShape(const framework::InferShapeContext& ctx) const override { PADDLE_ENFORCE(GetAttr("min") < GetAttr("max"),