From 2ea2fbea1bfb6f73c87f7029953ba8007e8cf4fb Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Fri, 11 Aug 2017 17:30:49 -0700 Subject: [PATCH] Merge REGISTER_OP and REGISTER_GRADIENT_OP --- paddle/framework/backward_test.cc | 16 ++++++---------- paddle/framework/grad_op_builder_test.cc | 13 ++----------- paddle/framework/op_registry.h | 22 +++++++++++++--------- paddle/framework/operator.h | 7 +++++++ paddle/operators/add_op.cc | 3 +-- paddle/operators/cross_entropy_op.cc | 5 ++--- paddle/operators/mean_op.cc | 3 +-- paddle/operators/mul_op.cc | 4 +--- paddle/operators/sigmoid_op.cc | 5 ++--- paddle/operators/softmax_op.cc | 4 ++-- 10 files changed, 37 insertions(+), 45 deletions(-) diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 38194b716..4136e2c36 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -150,20 +150,16 @@ class AddOpMaker : public OpProtoAndCheckerMaker { namespace f = paddle::framework; namespace ops = paddle::operators; using EnforceNotMet = paddle::platform::EnforceNotMet; -REGISTER_OP(rowwise_add, f::EmptyOp, f::RowWiseAddOpMaker, rowwise_add_grad); -REGISTER_GRADIENT_OP(rowwise_add_grad, f::EmptyOp); -REGISTER_OP(mul, f::EmptyOp, f::MulOpMaker, mul_grad); -REGISTER_GRADIENT_OP(mul_grad, f::EmptyOp); -REGISTER_OP(sigmoid, f::EmptyOp, f::SigmoidOpMaker, sigmoid_grad); -REGISTER_GRADIENT_OP(sigmoid_grad, f::EmptyOp); +REGISTER_OP(rowwise_add, f::EmptyOp, f::RowWiseAddOpMaker, rowwise_add_grad, + f::EmptyOp); +REGISTER_OP(mul, f::EmptyOp, f::MulOpMaker, mul_grad, f::EmptyOp); +REGISTER_OP(sigmoid, f::EmptyOp, f::SigmoidOpMaker, sigmoid_grad, f::EmptyOp); REGISTER_OP_WITHOUT_GRADIENT(nograd, f::EmptyOp, f::NoGradOpMaker); REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker); -REGISTER_OP(add, f::EmptyOp, f::AddOpMaker, add_grad); -REGISTER_GRADIENT_OP(add_grad, f::EmptyOp); +REGISTER_OP(add, f::EmptyOp, f::AddOpMaker, add_grad, f::EmptyOp); REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker); REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker, - many_output_op_grad); -REGISTER_GRADIENT_OP(many_output_op_grad, f::EmptyOp); + many_output_op_grad, f::EmptyOp); TEST(Backward, simple_op_grad) { auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); diff --git a/paddle/framework/grad_op_builder_test.cc b/paddle/framework/grad_op_builder_test.cc index ad61b482e..3d7f1a753 100644 --- a/paddle/framework/grad_op_builder_test.cc +++ b/paddle/framework/grad_op_builder_test.cc @@ -8,13 +8,6 @@ USE_OP(add_two); namespace paddle { namespace framework { -class NOP : public OperatorBase { - public: - void InferShape(const Scope &scope) const override {} - void Run(const Scope &scope, - const platform::DeviceContext &dev_ctx) const override {} -}; - class MutiInOutOpMaker : public OpProtoAndCheckerMaker { public: MutiInOutOpMaker(OpProto *proto, OpAttrChecker *op_checker) @@ -61,10 +54,8 @@ TEST(GradOpBuilder, AddTwo) { EXPECT_EQ(grad_add_op->Output("Y@GRAD"), "y@GRAD"); } -REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, mult_io_grad); -REGISTER_GRADIENT_OP(mult_io_grad, f::NOP); -REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, io_ignored_grad); -REGISTER_GRADIENT_OP(io_ignored_grad, f::NOP); +REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, mult_io_grad, f::NOP); +REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, io_ignored_grad, f::NOP); TEST(GradOpBuilder, MutiInOut) { f::AttributeMap attrs{{"input_format", std::vector{0, 1, 4, 5}}, diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 69c5f549e..080a7149b 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -193,7 +193,7 @@ class OpRegistry { using VarNameList = std::vector; public: - template + template static void RegisterOp(const std::string& op_type, const std::string& grad_op_type) { PADDLE_ENFORCE(op_info_map().count(op_type) == 0, @@ -226,6 +226,10 @@ class OpRegistry { // ================================================ // } op_info_map().insert(std::make_pair(op_type, op_info)); + // register gradient op + if (!grad_op_type.empty()) { + RegisterOp(grad_op_type, ""); + } } static std::shared_ptr CreateOp(const std::string& type, @@ -321,12 +325,13 @@ class Registrar { void Touch() {} }; -template +template class OpRegistrar : public Registrar { public: explicit OpRegistrar(const char* op_type) { OpRegistrar(op_type, ""); } OpRegistrar(const char* op_type, const char* grad_op_type) { - OpRegistry::RegisterOp(op_type, grad_op_type); + OpRegistry::RegisterOp(op_type, + grad_op_type); } }; @@ -352,10 +357,12 @@ class OpKernelRegistrar : public Registrar { /** * Macro to register Operator. */ -#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type) \ +#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \ + grad_op_class) \ STATIC_ASSERT_GLOBAL_NAMESPACE( \ __reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \ - static ::paddle::framework::OpRegistrar \ + static ::paddle::framework::OpRegistrar \ __op_registrar_##op_type##__(#op_type, #grad_op_type); \ int TouchOpRegistrar_##op_type() { \ __op_registrar_##op_type##__.Touch(); \ @@ -363,10 +370,7 @@ class OpKernelRegistrar : public Registrar { } #define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \ - REGISTER_OP(op_type, op_class, op_maker_class, ) - -#define REGISTER_GRADIENT_OP(op_type, op_class) \ - REGISTER_OP(op_type, op_class, ::paddle::framework::NOPMaker, ) + REGISTER_OP(op_type, op_class, op_maker_class, , ::paddle::framework::NOP) /** * Macro to register OperatorKernel. diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index f5d167a16..13308e0da 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -125,6 +125,13 @@ class OperatorBase { std::shared_ptr> in_out_idxs_; }; +class NOP : public OperatorBase { + public: + void InferShape(const Scope& scope) const override {} + void Run(const Scope& scope, + const platform::DeviceContext& dev_ctx) const override {} +}; + class InferShapeContext { public: InferShapeContext(const OperatorBase& op, const Scope& scope) diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc index e8e26cbe9..447e7b391 100644 --- a/paddle/operators/add_op.cc +++ b/paddle/operators/add_op.cc @@ -55,8 +55,7 @@ class AddOpGrad : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker, add_two_grad); -REGISTER_GRADIENT_OP(add_two_grad, ops::AddOpGrad); +REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker, add_two_grad, ops::AddOpGrad); REGISTER_OP_CPU_KERNEL(add_two, ops::AddKernel); diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc index 7d0e74e5e..3dcaccd75 100644 --- a/paddle/operators/cross_entropy_op.cc +++ b/paddle/operators/cross_entropy_op.cc @@ -69,12 +69,11 @@ OnehotCrossEntropy Operator. namespace ops = paddle::operators; REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp, - ops::OnehotCrossEntropyOpMaker, onehot_cross_entropy_grad); + ops::OnehotCrossEntropyOpMaker, onehot_cross_entropy_grad, + ops::OnehotCrossEntropyGradientOp); REGISTER_OP_CPU_KERNEL( onehot_cross_entropy, ops::OnehotCrossEntropyOpKernel); -REGISTER_GRADIENT_OP(onehot_cross_entropy_grad, - ops::OnehotCrossEntropyGradientOp); REGISTER_OP_CPU_KERNEL( onehot_cross_entropy_grad, ops::OnehotCrossEntropyGradientOpKernel); diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc index 15e0708c4..c41208014 100644 --- a/paddle/operators/mean_op.cc +++ b/paddle/operators/mean_op.cc @@ -50,9 +50,8 @@ class MeanGradOp : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker, mean_grad); +REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker, mean_grad, ops::MeanGradOp); REGISTER_OP_CPU_KERNEL(mean, ops::MeanKernel); -REGISTER_GRADIENT_OP(mean_grad, ops::MeanGradOp); REGISTER_OP_CPU_KERNEL(mean_grad, ops::MeanGradKernel); diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 60550a274..0c4547f04 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -65,7 +65,5 @@ class MulOpGrad : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad); -REGISTER_GRADIENT_OP(mul_grad, ops::MulOpGrad); - +REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulOpGrad); REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel); diff --git a/paddle/operators/sigmoid_op.cc b/paddle/operators/sigmoid_op.cc index fb27ffbfa..4f3a880b4 100644 --- a/paddle/operators/sigmoid_op.cc +++ b/paddle/operators/sigmoid_op.cc @@ -48,9 +48,8 @@ class SigmoidOpGrad : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker, sigmoid_grad); -REGISTER_GRADIENT_OP(sigmoid_grad, ops::SigmoidOpGrad); - +REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker, sigmoid_grad, + ops::SigmoidOpGrad); REGISTER_OP_CPU_KERNEL(sigmoid, ops::SigmoidKernel); REGISTER_OP_CPU_KERNEL( diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc index abc21337c..99bc5b77d 100644 --- a/paddle/operators/softmax_op.cc +++ b/paddle/operators/softmax_op.cc @@ -64,9 +64,9 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { namespace ops = paddle::operators; -REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker, softmax_grad); +REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker, softmax_grad, + ops::SoftmaxOpGrad); REGISTER_OP_CPU_KERNEL(softmax, ops::SoftmaxKernel); -REGISTER_GRADIENT_OP(softmax_grad, ops::SoftmaxOpGrad); REGISTER_OP_CPU_KERNEL( softmax_grad, ops::SoftmaxGradKernel); -- GitLab