From 9a3c69c268d65c340ed5d0ae9aee1a90ca50ee84 Mon Sep 17 00:00:00 2001
From: fengjiayi
Date: Mon, 4 Sep 2017 10:45:11 -0700
Subject: [PATCH] Revert "Remove `grad_op_type` in `REGISTER_OP`"

---
 doc/howto/dev/new_op_cn.md               |  4 ++--
 paddle/framework/backward.md             |  2 +-
 paddle/framework/backward_test.cc        | 12 +++++++-----
 paddle/framework/grad_op_builder_test.cc |  4 ++--
 paddle/framework/op_registry.h           | 25 ++++++++++++++----------
 paddle/operators/add_op.cc               |  2 +-
 paddle/operators/cross_entropy_op.cc     |  3 ++-
 paddle/operators/gather_op.cc            |  3 ++-
 paddle/operators/lookup_table_op.cc      |  2 +-
 paddle/operators/mean_op.cc              |  2 +-
 paddle/operators/minus_op.cc             |  3 ++-
 paddle/operators/mul_op.cc               |  2 +-
 paddle/operators/rowwise_add_op.cc       |  2 +-
 paddle/operators/scale_op.cc             |  2 +-
 paddle/operators/scatter_op.cc           |  3 ++-
 paddle/operators/sigmoid_op.cc           |  3 ++-
 paddle/operators/softmax_op.cc           |  3 ++-
 17 files changed, 45 insertions(+), 32 deletions(-)

diff --git a/doc/howto/dev/new_op_cn.md b/doc/howto/dev/new_op_cn.md
index ec79b7f42b2..7f8da2da5a0 100644
--- a/doc/howto/dev/new_op_cn.md
+++ b/doc/howto/dev/new_op_cn.md
@@ -178,13 +178,13 @@ class MulKernel : public framework::OpKernel {
 
 ```c++
 namespace ops = paddle::operators;
-REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, ops::MulOpGrad);
+REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulOpGrad);
 REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(mul_grad,
                        ops::MulGradKernel<paddle::platform::CPUPlace, float>);
 ```
 
- - `REGISTER_OP` : 注册`ops::MulOp`类,类型名为`mul`,该类的`ProtoMaker`为`ops::MulOpMaker`,并且注册`ops::MulOpGrad`为其反向Op。
+ - `REGISTER_OP` : 注册`ops::MulOp`类,类型名为`mul`,该类的`ProtoMaker`为`ops::MulOpMaker`,注册`ops::MulOpGrad`,类型名为`mul_grad`,
 - `REGISTER_OP_WITHOUT_GRADIENT` : 用于注册没有反向的Op。
 - `REGISTER_OP_CPU_KERNEL` :注册`ops::MulKernel`类,并特化模板参数为`paddle::platform::CPUPlace`和`float`类型,同理,注册`ops::MulGradKernel`类。
 
diff --git a/paddle/framework/backward.md b/paddle/framework/backward.md
index 9500c92a265..8aa6728a95b 100644
--- a/paddle/framework/backward.md
+++ b/paddle/framework/backward.md
@@ -18,7 +18,7 @@ A backward network is built up with several backward operators. Backward operato
 For example, we have got a `mul_op`, and we can register its information and corresponding backward operator by the following macro:
 
 ```cpp
-REGISTER_OP(mul, MulOp, MulOpMaker, MulOpGrad);
+REGISTER_OP(mul, MulOp, MulOpMaker, mul_grad, MulOpGrad);
 ```
 
 `mul` is the operator's type. `MulOp` and `MulOpMaker` are the operator class and the operator maker class respectively.
diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc
index bf8b11e5f5a..ad8003420dc 100644
--- a/paddle/framework/backward_test.cc
+++ b/paddle/framework/backward_test.cc
@@ -148,14 +148,16 @@ class AddOpMaker : public OpProtoAndCheckerMaker {
 namespace f = paddle::framework;
 namespace ops = paddle::operators;
 using EnforceNotMet = paddle::platform::EnforceNotMet;
-REGISTER_OP(rowwise_add, f::NOP, f::RowWiseAddOpMaker, f::NOP);
-REGISTER_OP(mul, f::NOP, f::MulOpMaker, f::NOP);
-REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, f::NOP);
+REGISTER_OP(rowwise_add, f::NOP, f::RowWiseAddOpMaker, rowwise_add_grad,
+            f::NOP);
+REGISTER_OP(mul, f::NOP, f::MulOpMaker, mul_grad, f::NOP);
+REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, sigmoid_grad, f::NOP);
 REGISTER_OP_WITHOUT_GRADIENT(nograd, f::NOP, f::NoGradOpMaker);
 REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, f::NOP, f::FillZeroOpMaker);
-REGISTER_OP(add, f::NOP, f::AddOpMaker, f::NOP);
+REGISTER_OP(add, f::NOP, f::AddOpMaker, add_grad, f::NOP);
 REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker);
-REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, f::NOP);
+REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, many_output_op_grad,
+            f::NOP);
 
 TEST(Backward, simple_op_grad) {
   auto fwd = f::OpRegistry::CreateOp(
diff --git a/paddle/framework/grad_op_builder_test.cc b/paddle/framework/grad_op_builder_test.cc
index 8a817a3e13c..902c2655e91 100644
--- a/paddle/framework/grad_op_builder_test.cc
+++ b/paddle/framework/grad_op_builder_test.cc
@@ -54,8 +54,8 @@ TEST(GradOpBuilder, AddTwo) {
   EXPECT_EQ(grad_add_op->Output(f::GradVarName("Y")), f::GradVarName("y"));
 }
 
-REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, f::NOP);
-REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, f::NOP);
+REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, mult_io_grad, f::NOP);
+REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, io_ignored_grad, f::NOP);
 
 TEST(GradOpBuilder, MutiInOut) {
   std::shared_ptr<f::OperatorBase> test_op(f::OpRegistry::CreateOp(
diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h
index 64c7f23ab6b..2d09cde41e3 100644
--- a/paddle/framework/op_registry.h
+++ b/paddle/framework/op_registry.h
@@ -33,7 +33,8 @@ namespace framework {
 class OpRegistry {
  public:
   template <typename OpType, typename ProtoMakerType, typename GradOpType>
-  static void RegisterOp(const std::string& op_type) {
+  static void RegisterOp(const std::string& op_type,
+                         const std::string& grad_op_type) {
     PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type),
                    "'%s' is registered more than once.", op_type);
     OpInfo op_info;
@@ -42,9 +43,9 @@
         const VariableNameMap& outputs, const AttributeMap& attrs) {
       return new OpType(type, inputs, outputs, attrs);
     };
+    op_info.grad_op_type_ = grad_op_type;
     if (std::type_index(typeid(ProtoMakerType)) !=
         std::type_index(typeid(NOPMaker))) {
-      op_info.grad_op_type_ = op_type + "_grad";
       op_info.proto_ = new OpProto;
       op_info.checker_ = new OpAttrChecker;
       auto maker = ProtoMakerType(op_info.proto_, op_info.checker_);
@@ -54,14 +55,15 @@
           op_info.proto_->IsInitialized(),
           "Fail to initialize %s's OpProto, because %s is not initialized",
           op_type, op_info.proto_->InitializationErrorString());
-      // register gradient op
-      RegisterOp<GradOpType, NOPMaker, NOP>(op_info.grad_op_type_);
     } else {
-      op_info.grad_op_type_ = "";
       op_info.proto_ = nullptr;
       op_info.checker_ = nullptr;
     }
     OpInfoMap::Instance().Insert(op_type, op_info);
+    // register gradient op
+    if (!grad_op_type.empty()) {
+      RegisterOp<GradOpType, NOPMaker, NOP>(grad_op_type, "");
+    }
   }
 
   static std::unique_ptr<OperatorBase> CreateOp(const std::string& type,
@@ -90,8 +92,10 @@ class Registrar {
 template <typename OpType, typename ProtoMakerType, typename GradOpType>
 class OpRegistrar : public Registrar {
  public:
-  explicit OpRegistrar(const char* op_type) {
-    OpRegistry::RegisterOp<OpType, ProtoMakerType, GradOpType>(op_type);
+  explicit OpRegistrar(const char* op_type) { OpRegistrar(op_type, ""); }
+  OpRegistrar(const char* op_type, const char* grad_op_type) {
+    OpRegistry::RegisterOp<OpType, ProtoMakerType, GradOpType>(op_type,
+                                                               grad_op_type);
   }
 };
@@ -117,7 +121,8 @@ class OpKernelRegistrar : public Registrar {
 /**
  * Macro to register Operator.
  */
-#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_class)         \
+#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type,          \
+                    grad_op_class)                                            \
   STATIC_ASSERT_GLOBAL_NAMESPACE(                                             \
       __reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \
   class _OpClass_##op_type##_ : public op_class {                             \
@@ -132,14 +137,14 @@ class OpKernelRegistrar : public Registrar {
   };                                                                          \
   static ::paddle::framework::OpRegistrar<                                    \
       _OpClass_##op_type##_, op_maker_class, _OpGradClass_##op_type##_>       \
-      __op_registrar_##op_type##__(#op_type);                                 \
+      __op_registrar_##op_type##__(#op_type, #grad_op_type);                  \
   int TouchOpRegistrar_##op_type() {                                          \
     __op_registrar_##op_type##__.Touch();                                     \
     return 0;                                                                 \
   }
 
 #define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
-  REGISTER_OP(op_type, op_class, op_maker_class, ::paddle::framework::NOP)
+  REGISTER_OP(op_type, op_class, op_maker_class, , ::paddle::framework::NOP)
 
 /**
  * Macro to register OperatorKernel.
diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc
index 6384d8c8ce1..8ab748ed71e 100644
--- a/paddle/operators/add_op.cc
+++ b/paddle/operators/add_op.cc
@@ -57,7 +57,7 @@ class AddOpGrad : public framework::OperatorWithKernel {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker, ops::AddOpGrad);
+REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker, add_two_grad, ops::AddOpGrad);
 REGISTER_OP_CPU_KERNEL(add_two,
                        ops::AddKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc
index ac76326262c..ab1e1c101a1 100644
--- a/paddle/operators/cross_entropy_op.cc
+++ b/paddle/operators/cross_entropy_op.cc
@@ -67,7 +67,8 @@ OnehotCrossEntropy Operator.
 
 namespace ops = paddle::operators;
 REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp,
-            ops::OnehotCrossEntropyOpMaker, ops::OnehotCrossEntropyGradientOp);
+            ops::OnehotCrossEntropyOpMaker, onehot_cross_entropy_grad,
+            ops::OnehotCrossEntropyGradientOp);
 REGISTER_OP_CPU_KERNEL(onehot_cross_entropy,
                        ops::OnehotCrossEntropyOpKernel<float>);
 REGISTER_OP_CPU_KERNEL(onehot_cross_entropy_grad,
diff --git a/paddle/operators/gather_op.cc b/paddle/operators/gather_op.cc
index 07fa7048241..123bed296c4 100644
--- a/paddle/operators/gather_op.cc
+++ b/paddle/operators/gather_op.cc
@@ -63,7 +63,8 @@ Out = X[Index]
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP(gather, ops::GatherOp, ops::GatherOpMaker, ops::GatherGradOp);
+REGISTER_OP(gather, ops::GatherOp, ops::GatherOpMaker, gather_grad,
+            ops::GatherGradOp);
 REGISTER_OP_CPU_KERNEL(gather,
                        ops::GatherOpKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
diff --git a/paddle/operators/lookup_table_op.cc b/paddle/operators/lookup_table_op.cc
index c3108ba8ec7..94d40890a76 100644
--- a/paddle/operators/lookup_table_op.cc
+++ b/paddle/operators/lookup_table_op.cc
@@ -66,7 +66,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
 namespace ops = paddle::operators;
 
 REGISTER_OP(lookup_table, ops::LookupTableOp, ops::LookupTableOpMaker,
-            ops::LookupTableOpGrad);
+            lookup_table_grad, ops::LookupTableOpGrad);
 
 REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>);
 REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel<float>);
diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc
index e66e0abb25f..d3d0e55a674 100644
--- a/paddle/operators/mean_op.cc
+++ b/paddle/operators/mean_op.cc
@@ -54,7 +54,7 @@ class MeanGradOp : public framework::OperatorWithKernel {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanGradOp);
+REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker, mean_grad, ops::MeanGradOp);
 REGISTER_OP_CPU_KERNEL(mean,
                        ops::MeanKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(mean_grad,
diff --git a/paddle/operators/minus_op.cc b/paddle/operators/minus_op.cc
index b4afebcd97a..1eee9644bab 100644
--- a/paddle/operators/minus_op.cc
+++ b/paddle/operators/minus_op.cc
@@ -81,6 +81,7 @@ class MinusGradOp : public NetOp {
 USE_OP(scale);
 USE_OP_ITSELF(identity);
 namespace ops = paddle::operators;
-REGISTER_OP(minus, ops::MinusOp, ops::MinusOpMaker, ops::MinusGradOp);
+REGISTER_OP(minus, ops::MinusOp, ops::MinusOpMaker, minus_grad,
+            ops::MinusGradOp);
 REGISTER_OP_CPU_KERNEL(minus,
                        ops::MinusKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc
index 8d0f59745fd..5b8b5f6c118 100644
--- a/paddle/operators/mul_op.cc
+++ b/paddle/operators/mul_op.cc
@@ -84,7 +84,7 @@ class MulOpGrad : public framework::OperatorWithKernel {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, ops::MulOpGrad);
+REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulOpGrad);
 REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(mul_grad,
                        ops::MulGradKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc
index 63de91254f4..6825dce332a 100644
--- a/paddle/operators/rowwise_add_op.cc
+++ b/paddle/operators/rowwise_add_op.cc
@@ -74,7 +74,7 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel {
 
 namespace ops = paddle::operators;
 REGISTER_OP(rowwise_add, ops::RowwiseAddOp, ops::RowwiseAddOpMaker,
-            ops::RowwiseAddGradOp);
+            rowwise_add_grad, ops::RowwiseAddGradOp);
 REGISTER_OP_CPU_KERNEL(
     rowwise_add, ops::RowwiseAddKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
diff --git a/paddle/operators/scale_op.cc b/paddle/operators/scale_op.cc
index 4e039688d4d..8e96a74c94a 100644
--- a/paddle/operators/scale_op.cc
+++ b/paddle/operators/scale_op.cc
@@ -97,7 +97,7 @@ class IdentityOp : public NetOp {
 
 namespace ops = paddle::operators;
 
-REGISTER_OP(scale, ops::ScaleOp, ops::ScaleOpMaker,
+REGISTER_OP(scale, ops::ScaleOp, ops::ScaleOpMaker, scale_grad,
             ops::ScaleGradOp<float>);
 REGISTER_OP_CPU_KERNEL(scale,
                        ops::ScaleKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/scatter_op.cc b/paddle/operators/scatter_op.cc
index 35c185ad80f..f901edefa22 100644
--- a/paddle/operators/scatter_op.cc
+++ b/paddle/operators/scatter_op.cc
@@ -77,7 +77,8 @@ Out[Index] = Ref[Index] + Updates
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP(scatter, ops::ScatterOp, ops::ScatterOpMaker, ops::ScatterGradOp);
+REGISTER_OP(scatter, ops::ScatterOp, ops::ScatterOpMaker, scatter_grad,
+            ops::ScatterGradOp);
 REGISTER_OP_CPU_KERNEL(scatter,
                        ops::ScatterOpKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
diff --git a/paddle/operators/sigmoid_op.cc b/paddle/operators/sigmoid_op.cc
index f35b7023845..761c6de8d4d 100644
--- a/paddle/operators/sigmoid_op.cc
+++ b/paddle/operators/sigmoid_op.cc
@@ -53,7 +53,8 @@ class SigmoidOpGrad : public framework::OperatorWithKernel {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker, ops::SigmoidOpGrad);
+REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker, sigmoid_grad,
+            ops::SigmoidOpGrad);
 REGISTER_OP_CPU_KERNEL(sigmoid,
                        ops::SigmoidKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc
index 471bb288fb2..40c51a64c49 100644
--- a/paddle/operators/softmax_op.cc
+++ b/paddle/operators/softmax_op.cc
@@ -62,7 +62,8 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
 
 namespace ops = paddle::operators;
 
-REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker, ops::SoftmaxOpGrad);
+REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker, softmax_grad,
+            ops::SoftmaxOpGrad);
 REGISTER_OP_CPU_KERNEL(softmax,
                        ops::SoftmaxKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
-- 
GitLab
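
Note (not part of the commit): the hunks above restore the five-argument form of `REGISTER_OP`, in which the gradient operator's type name is passed explicitly instead of being derived as `op_type + "_grad"`. Below is a minimal usage sketch in the style of the `mul` example from `doc/howto/dev/new_op_cn.md`; the operator name `my_op` and the classes `MyOp`, `MyOpMaker`, `MyOpGrad`, `MyFwdOp`, and `MyFwdOpMaker` are hypothetical placeholders, not part of the patch.

```cpp
// Hypothetical operator registration following the restored macro signature.
// MyOp, MyOpMaker, and MyOpGrad stand in for real operator classes.
namespace ops = paddle::operators;

// Old four-argument form (before this revert), where the gradient op type
// was implicitly registered as "my_op_grad":
//   REGISTER_OP(my_op, ops::MyOp, ops::MyOpMaker, ops::MyOpGrad);

// Restored five-argument form: the gradient op type is an explicit argument,
// so the gradient op is registered under exactly the name given here.
REGISTER_OP(my_op, ops::MyOp, ops::MyOpMaker, my_op_grad, ops::MyOpGrad);

// Forward-only operators pass an empty grad_op_type through the convenience
// macro, which expands to REGISTER_OP(..., , ::paddle::framework::NOP).
REGISTER_OP_WITHOUT_GRADIENT(my_fwd_op, ops::MyFwdOp, ops::MyFwdOpMaker);
```

Because `OpRegistry::RegisterOp` now skips gradient registration whenever `grad_op_type` is empty, `REGISTER_OP_WITHOUT_GRADIENT` can simply leave that argument blank.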