diff --git a/doc/howto/dev/new_op_cn.md b/doc/howto/dev/new_op_cn.md
index ec79b7f42b2d70df8fcb25faca5bc3a4759e177c..7f8da2da5a0d42ff065265c5d173d0e6167dc08a 100644
--- a/doc/howto/dev/new_op_cn.md
+++ b/doc/howto/dev/new_op_cn.md
@@ -178,13 +178,13 @@ class MulKernel : public framework::OpKernel {
 
 ```c++
 namespace ops = paddle::operators;
-REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, ops::MulOpGrad);
+REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulOpGrad);
 REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(mul_grad,
               ops::MulGradKernel<paddle::platform::CPUPlace, float>);
 ```
 
- - `REGISTER_OP` : registers the `ops::MulOp` class under the type name `mul`, with `ops::MulOpMaker` as its `ProtoMaker`, and registers `ops::MulOpGrad` as its backward Op.
+ - `REGISTER_OP` : registers the `ops::MulOp` class under the type name `mul`, with `ops::MulOpMaker` as its `ProtoMaker`, and registers `ops::MulOpGrad` as its backward Op under the type name `mul_grad`.
 - `REGISTER_OP_WITHOUT_GRADIENT` : registers an Op that has no backward Op.
 - `REGISTER_OP_CPU_KERNEL` : registers the `ops::MulKernel` class, specializing its template parameters to `paddle::platform::CPUPlace` and `float`; `ops::MulGradKernel` is registered in the same way.
diff --git a/paddle/framework/backward.md b/paddle/framework/backward.md
index 9500c92a265d60a696e1e2c422d0f2bd1621ef71..8aa6728a95bc464ab8884986f0cec6c817d3303b 100644
--- a/paddle/framework/backward.md
+++ b/paddle/framework/backward.md
@@ -18,7 +18,7 @@ A backward network is built up with several backward operators. Backward operato
 For example, suppose we have a `mul_op`; we can register its information and corresponding backward operator with the following macro:
 
 ```cpp
-REGISTER_OP(mul, MulOp, MulOpMaker, MulOpGrad);
+REGISTER_OP(mul, MulOp, MulOpMaker, mul_grad, MulOpGrad);
 ```
 
 `mul` is the operator's type. `MulOp` and `MulOpMaker` are the operator class and the operator maker class respectively; `mul_grad` and `MulOpGrad` are the type and class of the backward operator.
diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc
index bf8b11e5f5ae801621f84bdbeffb5c4cf2dd8905..ad8003420dc14538d0dae9a1cb19d6459b154576 100644
--- a/paddle/framework/backward_test.cc
+++ b/paddle/framework/backward_test.cc
@@ -148,14 +148,16 @@ class AddOpMaker : public OpProtoAndCheckerMaker {
 namespace f = paddle::framework;
 namespace ops = paddle::operators;
 using EnforceNotMet = paddle::platform::EnforceNotMet;
-REGISTER_OP(rowwise_add, f::NOP, f::RowWiseAddOpMaker, f::NOP);
-REGISTER_OP(mul, f::NOP, f::MulOpMaker, f::NOP);
-REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, f::NOP);
+REGISTER_OP(rowwise_add, f::NOP, f::RowWiseAddOpMaker, rowwise_add_grad,
+            f::NOP);
+REGISTER_OP(mul, f::NOP, f::MulOpMaker, mul_grad, f::NOP);
+REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, sigmoid_grad, f::NOP);
 REGISTER_OP_WITHOUT_GRADIENT(nograd, f::NOP, f::NoGradOpMaker);
 REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, f::NOP, f::FillZeroOpMaker);
-REGISTER_OP(add, f::NOP, f::AddOpMaker, f::NOP);
+REGISTER_OP(add, f::NOP, f::AddOpMaker, add_grad, f::NOP);
 REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker);
-REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, f::NOP);
+REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, many_output_op_grad,
+            f::NOP);
 
 TEST(Backward, simple_op_grad) {
   auto fwd = f::OpRegistry::CreateOp(
diff --git a/paddle/framework/grad_op_builder_test.cc b/paddle/framework/grad_op_builder_test.cc
index 8a817a3e13ca64d6f8df566891a1059995e041ae..902c2655e9182d74a48ad13e17a39a3304d5fa57 100644
--- a/paddle/framework/grad_op_builder_test.cc
+++ b/paddle/framework/grad_op_builder_test.cc
@@ -54,8 +54,8 @@ TEST(GradOpBuilder, AddTwo) {
   EXPECT_EQ(grad_add_op->Output(f::GradVarName("Y")), f::GradVarName("y"));
 }
 
-REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, f::NOP);
-REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, f::NOP);
+REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, mult_io_grad, f::NOP);
+REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, io_ignored_grad, f::NOP);
 
 TEST(GradOpBuilder, MutiInOut) {
   std::shared_ptr<f::OperatorBase> test_op(f::OpRegistry::CreateOp(
diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h
index 64c7f23ab6b79bad9533f566ca39db3cfd5ac5c5..2d09cde41e3f5086279f9441e0fdc52549bed5ab 100644
--- a/paddle/framework/op_registry.h
+++ b/paddle/framework/op_registry.h
@@ -33,7 +33,8 @@ namespace framework {
 class OpRegistry {
  public:
   template <typename OpType, typename ProtoMakerType, typename GradOpType>
-  static void RegisterOp(const std::string& op_type) {
+  static void RegisterOp(const std::string& op_type,
+                         const std::string& grad_op_type) {
     PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type),
                    "'%s' is registered more than once.", op_type);
     OpInfo op_info;
@@ -42,9 +43,9 @@ class OpRegistry {
                           const VariableNameMap& outputs,
                           const AttributeMap& attrs) {
       return new OpType(type, inputs, outputs, attrs);
     };
+    op_info.grad_op_type_ = grad_op_type;
     if (std::type_index(typeid(ProtoMakerType)) !=
         std::type_index(typeid(NOPMaker))) {
-      op_info.grad_op_type_ = op_type + "_grad";
       op_info.proto_ = new OpProto;
       op_info.checker_ = new OpAttrChecker;
       auto maker = ProtoMakerType(op_info.proto_, op_info.checker_);
@@ -54,14 +55,15 @@ class OpRegistry {
           op_info.proto_->IsInitialized(),
           "Fail to initialize %s's OpProto, because %s is not initialized",
           op_type, op_info.proto_->InitializationErrorString());
-      // register gradient op
-      RegisterOp<GradOpType, NOPMaker, NOP>(op_info.grad_op_type_);
     } else {
-      op_info.grad_op_type_ = "";
       op_info.proto_ = nullptr;
       op_info.checker_ = nullptr;
     }
     OpInfoMap::Instance().Insert(op_type, op_info);
+    // register gradient op
+    if (!grad_op_type.empty()) {
+      RegisterOp<GradOpType, NOPMaker, NOP>(grad_op_type, "");
+    }
   }
 
   static std::unique_ptr<OperatorBase> CreateOp(const std::string& type,
@@ -90,8 +92,10 @@ class Registrar {
 template <typename OpType, typename ProtoMakerType, typename GradOpType>
 class OpRegistrar : public Registrar {
  public:
-  explicit OpRegistrar(const char* op_type) {
-    OpRegistry::RegisterOp<OpType, ProtoMakerType, GradOpType>(op_type);
-  }
+  explicit OpRegistrar(const char* op_type) { OpRegistrar(op_type, ""); }
+  OpRegistrar(const char* op_type, const char* grad_op_type) {
+    OpRegistry::RegisterOp<OpType, ProtoMakerType, GradOpType>(op_type,
+                                                               grad_op_type);
+  }
 };
 
@@ -117,7 +121,8 @@ class OpKernelRegistrar : public Registrar {
 /**
  * Macro to register Operator.
  */
-#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_class)         \
+#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type,          \
+                    grad_op_class)                                            \
   STATIC_ASSERT_GLOBAL_NAMESPACE(                                             \
       __reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \
   class _OpClass_##op_type##_ : public op_class {                             \
@@ -132,14 +137,14 @@ class OpKernelRegistrar : public Registrar {
   };                                                                          \
   static ::paddle::framework::OpRegistrar<                                    \
       _OpClass_##op_type##_, op_maker_class, _OpGradClass_##op_type##_>       \
-      __op_registrar_##op_type##__(#op_type);                                 \
+      __op_registrar_##op_type##__(#op_type, #grad_op_type);                  \
   int TouchOpRegistrar_##op_type() {                                          \
     __op_registrar_##op_type##__.Touch();                                     \
     return 0;                                                                 \
   }
 
 #define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
-  REGISTER_OP(op_type, op_class, op_maker_class, ::paddle::framework::NOP)
+  REGISTER_OP(op_type, op_class, op_maker_class, , ::paddle::framework::NOP)
 
 /**
  * Macro to register OperatorKernel.
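Reviewer note on the `op_registry.h` change above: a minimal usage sketch of the new macro signature. `FooOp`, `FooOpMaker`, `FooGradOp`, and `BarOp` are hypothetical stand-ins, not classes from this patch; the point is that the fourth macro argument now names the gradient op explicitly, where the old four-argument form always derived it as `op_type + "_grad"` inside `OpRegistry::RegisterOp`.

```cpp
#include "paddle/framework/op_registry.h"

namespace ops = paddle::operators;

// New five-argument form: the gradient op's type name is an explicit
// argument, so it no longer has to be the forward type plus "_grad".
REGISTER_OP(foo, ops::FooOp, ops::FooOpMaker, foo_grad, ops::FooGradOp);

// Gradient-free ops keep the three-argument macro; it now expands to
// REGISTER_OP with an empty grad_op_type, so no gradient op is registered.
REGISTER_OP_WITHOUT_GRADIENT(bar, ops::BarOp, ops::BarOpMaker);
```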
diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc
index 6384d8c8ce13dae8b58ed1069d496dd8e93eaa8a..8ab748ed71e9a5dc0ee0259a78a2b886870bec5b 100644
--- a/paddle/operators/add_op.cc
+++ b/paddle/operators/add_op.cc
@@ -57,7 +57,7 @@ class AddOpGrad : public framework::OperatorWithKernel {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker, ops::AddOpGrad);
+REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker, add_two_grad, ops::AddOpGrad);
 REGISTER_OP_CPU_KERNEL(add_two,
                        ops::AddKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc
index ac76326262c88e2014cf64f7fb73b5a7338ab3e9..ab1e1c101a10e09a81f7785d2f1514822e3bdf15 100644
--- a/paddle/operators/cross_entropy_op.cc
+++ b/paddle/operators/cross_entropy_op.cc
@@ -67,7 +67,8 @@ OnehotCrossEntropy Operator.
 
 namespace ops = paddle::operators;
 REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp,
-            ops::OnehotCrossEntropyOpMaker, ops::OnehotCrossEntropyGradientOp);
+            ops::OnehotCrossEntropyOpMaker, onehot_cross_entropy_grad,
+            ops::OnehotCrossEntropyGradientOp);
 REGISTER_OP_CPU_KERNEL(
     onehot_cross_entropy,
     ops::OnehotCrossEntropyOpKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(onehot_cross_entropy_grad,
diff --git a/paddle/operators/gather_op.cc b/paddle/operators/gather_op.cc
index 07fa704824174f939e459093b245036771d9cd4f..123bed296c462c30bddd3bfbd530098fdbfe4856 100644
--- a/paddle/operators/gather_op.cc
+++ b/paddle/operators/gather_op.cc
@@ -63,7 +63,8 @@ Out = X[Index]
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP(gather, ops::GatherOp, ops::GatherOpMaker, ops::GatherGradOp);
+REGISTER_OP(gather, ops::GatherOp, ops::GatherOpMaker, gather_grad,
+            ops::GatherGradOp);
 REGISTER_OP_CPU_KERNEL(gather,
                        ops::GatherOpKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
diff --git a/paddle/operators/lookup_table_op.cc b/paddle/operators/lookup_table_op.cc
index c3108ba8ec7ad85bd3485c135bf03e514bc66cd1..94d40890a765413e88a35a6ad995ca97ac84dcda 100644
--- a/paddle/operators/lookup_table_op.cc
+++ b/paddle/operators/lookup_table_op.cc
@@ -66,7 +66,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
 namespace ops = paddle::operators;
 
 REGISTER_OP(lookup_table, ops::LookupTableOp, ops::LookupTableOpMaker,
-            ops::LookupTableOpGrad);
+            lookup_table_grad, ops::LookupTableOpGrad);
 REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>);
 REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel<float>);
diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc
index e66e0abb25f9b933025a6d098ed9dd9eb18a47a5..d3d0e55a674587fb04f43f24d0790de4358f035a 100644
--- a/paddle/operators/mean_op.cc
+++ b/paddle/operators/mean_op.cc
@@ -54,7 +54,7 @@ class MeanGradOp : public framework::OperatorWithKernel {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanGradOp);
+REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker, mean_grad, ops::MeanGradOp);
 REGISTER_OP_CPU_KERNEL(mean,
                        ops::MeanKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(mean_grad,
diff --git a/paddle/operators/minus_op.cc b/paddle/operators/minus_op.cc
index b4afebcd97a8efff70aaaa85bc2ec5455ddd05c5..1eee9644babbdfac68821ca774845ad8ebbd5aee 100644
--- a/paddle/operators/minus_op.cc
+++ b/paddle/operators/minus_op.cc
@@ -81,6 +81,7 @@ class MinusGradOp : public NetOp {
 USE_OP(scale);
 USE_OP_ITSELF(identity);
 namespace ops = paddle::operators;
-REGISTER_OP(minus, ops::MinusOp, ops::MinusOpMaker, ops::MinusGradOp);
+REGISTER_OP(minus, ops::MinusOp, ops::MinusOpMaker, minus_grad,
+            ops::MinusGradOp);
 REGISTER_OP_CPU_KERNEL(minus,
                        ops::MinusKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc
index 8d0f59745fd58eb975b952369a53e48584a45126..5b8b5f6c118cbe213e1783256a940dff6fdccc46 100644
--- a/paddle/operators/mul_op.cc
+++ b/paddle/operators/mul_op.cc
@@ -84,7 +84,7 @@ class MulOpGrad : public framework::OperatorWithKernel {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, ops::MulOpGrad);
+REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulOpGrad);
 REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(mul_grad,
                        ops::MulGradKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc
index 63de91254f4b75587cb2fb29aeb8ff7358ba8e76..6825dce332adc0dc11dda187d1bd367875b8603e 100644
--- a/paddle/operators/rowwise_add_op.cc
+++ b/paddle/operators/rowwise_add_op.cc
@@ -74,7 +74,7 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel {
 
 namespace ops = paddle::operators;
 REGISTER_OP(rowwise_add, ops::RowwiseAddOp, ops::RowwiseAddOpMaker,
-            ops::RowwiseAddGradOp);
+            rowwise_add_grad, ops::RowwiseAddGradOp);
 REGISTER_OP_CPU_KERNEL(
     rowwise_add, ops::RowwiseAddKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
diff --git a/paddle/operators/scale_op.cc b/paddle/operators/scale_op.cc
index 4e039688d4d74f2a101fc91c747bd1e6ebec7ad2..8e96a74c94ab7ff4d8c3266695e5157aff67905b 100644
--- a/paddle/operators/scale_op.cc
+++ b/paddle/operators/scale_op.cc
@@ -97,7 +97,7 @@ class IdentityOp : public NetOp {
 
 namespace ops = paddle::operators;
 
-REGISTER_OP(scale, ops::ScaleOp, ops::ScaleOpMaker<float>,
+REGISTER_OP(scale, ops::ScaleOp, ops::ScaleOpMaker<float>, scale_grad,
             ops::ScaleGradOp<float>);
 REGISTER_OP_CPU_KERNEL(scale,
                        ops::ScaleKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/scatter_op.cc b/paddle/operators/scatter_op.cc
index 35c185ad80f93d1005c1616dcffd2e61bcd54222..f901edefa22dc9a252e87116df756d04767a7162 100644
--- a/paddle/operators/scatter_op.cc
+++ b/paddle/operators/scatter_op.cc
@@ -77,7 +77,8 @@ Out[Index] = Ref[Index] + Updates
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP(scatter, ops::ScatterOp, ops::ScatterOpMaker, ops::ScatterGradOp);
+REGISTER_OP(scatter, ops::ScatterOp, ops::ScatterOpMaker, scatter_grad,
+            ops::ScatterGradOp);
 REGISTER_OP_CPU_KERNEL(
     scatter, ops::ScatterOpKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
diff --git a/paddle/operators/sigmoid_op.cc b/paddle/operators/sigmoid_op.cc
index f35b7023845bac52887d81a8f5c496cb5e7193aa..761c6de8d4d2150b30b97b58da95da3d5f33db63 100644
--- a/paddle/operators/sigmoid_op.cc
+++ b/paddle/operators/sigmoid_op.cc
@@ -53,7 +53,8 @@ class SigmoidOpGrad : public framework::OperatorWithKernel {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker, ops::SigmoidOpGrad);
+REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker, sigmoid_grad,
+            ops::SigmoidOpGrad);
 REGISTER_OP_CPU_KERNEL(sigmoid,
                        ops::SigmoidKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc
index 471bb288fb20f113aefb2a9e13eb805b161b0631..40c51a64c49bc064f55975ef6ced1d54070f1291 100644
--- a/paddle/operators/softmax_op.cc
+++ b/paddle/operators/softmax_op.cc
@@ -62,7 +62,8 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
 
 namespace ops = paddle::operators;
 
-REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker, ops::SoftmaxOpGrad);
+REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker, softmax_grad,
+            ops::SoftmaxOpGrad);
 REGISTER_OP_CPU_KERNEL(softmax,
                        ops::SoftmaxKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
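Closing reviewer note: a sketch of the observable effect of this patch. It assumes an accessor like `OpInfoMap::Instance().Get(...)` returning the stored `OpInfo`; only `Has()` and `Insert()` actually appear in this diff, so treat the call as illustrative.

```cpp
#include <cassert>
#include "paddle/framework/op_registry.h"

namespace f = paddle::framework;

void CheckGradOpTypes() {
  // With REGISTER_OP(mul, ..., mul_grad, ...), grad_op_type_ records the
  // explicit macro argument rather than a string derived as "mul" + "_grad".
  assert(f::OpInfoMap::Instance().Get("mul").grad_op_type_ == "mul_grad");
  // REGISTER_OP_WITHOUT_GRADIENT forwards an empty grad_op_type, so the
  // stored gradient type stays empty and no extra op gets registered.
  assert(f::OpInfoMap::Instance().Get("fill_zeros_like").grad_op_type_.empty());
}
```

The design gain is that a gradient op's type name is decoupled from the forward op's name, which the old hard-coded `"_grad"` suffix inside `OpRegistry::RegisterOp` could not express.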