diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt
index d018ee50c0de446d81191384b2c5f74714554dc4..b56107daf1bcf10f891bf400e60105c65f743f4b 100644
--- a/paddle/framework/CMakeLists.txt
+++ b/paddle/framework/CMakeLists.txt
@@ -29,4 +29,4 @@ add_dependencies(framework_py_proto framework_py_proto_init)
 proto_library(net_proto SRCS net_proto.proto DEPS op_proto)
 # cc_library(net SRCS net.cc DEPS operator net_proto op_registry fc_op)
 cc_library(net SRCS net.cc DEPS operator net_proto op_registry)
-cc_test(net_op_test SRCS net_op_test.cc DEPS net my_fc_op)
+cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op)
diff --git a/paddle/framework/net_op_test.cc b/paddle/framework/net_op_test.cc
index 18151c56d9acb3b10d5949f92b3e093d38c796e0..2e74235261211815c3e80637ec913f82c65b556f 100644
--- a/paddle/framework/net_op_test.cc
+++ b/paddle/framework/net_op_test.cc
@@ -2,7 +2,10 @@
 #include
 #include
 #include
-#include "paddle/framework/fully_connected_op.h"
+
+USE_OP(add_two);
+USE_OP(mul);
+USE_OP(sigmoid);
 
 namespace paddle {
 namespace framework {
@@ -65,14 +68,18 @@ TEST(OpKernel, all) {
   ASSERT_THROW(net->AddOp(op2), EnforceNotMet);
 }
 
-
 TEST(AddBackwardOp, TestGradOp) {
   auto net = std::make_shared();
   ASSERT_NE(net, nullptr);
-  auto op1 = std::make_shared();
-  op1->inputs_ = {"x", "w1", "b1"};
-  op1->outputs_ = {"y"};
-  net->AddOp(op1);
+  net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {}));
+  net->AddOp(
+      framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {}));
+  net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""}, {}));
+  //  net->AddOp(framework::OpRegistry::CreateOp("fc"), {
+  //      Input("X"), Input("W"), Input("b")},
+  //      {Output("Y")},
+  //      {}
+  //  );
   auto grad_ops = AddBackwardOp(net);
   for (auto& op : grad_ops->ops_) {
     op->DebugString();
diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h
index 92354f4ffdbba8312da477908637c054f30e6b83..07c3399462d45b99751f2d0f09aede47b0cea4c6 100644
--- a/paddle/framework/op_registry.h
+++ b/paddle/framework/op_registry.h
@@ -470,11 +470,11 @@ class GradOpRegisterHelper {
  */
 #define REGISTER_GRADIENT_OP(__op_type, __op_class)            \
   STATIC_ASSERT_GLOBAL_NAMESPACE(                              \
-      __reg_op__##__op_type,                                   \
+      __reg_gradient_op__##__op_type,                          \
       "REGISTER_GRADIENT_OP must be in global namespace");     \
   static ::paddle::framework::GradOpRegisterHelper<__op_class> \
-      __op_register_##__op_type##__(#__op_type);               \
-  int __op_register_##__op_type##_handle__() { return 0; }
+      __op_gradient_register_##__op_type##__(#__op_type);      \
+  int __op_gradient_register_##__op_type##_handle__() { return 0; }
 
 /**
  * Macro to Register OperatorKernel.
diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc
index 41d044cdb72b5fb2a7f8654e8ad103778e0857d1..f59a027407d18c428c60b9e79eda7aa759d957d6 100644
--- a/paddle/operators/add_op.cc
+++ b/paddle/operators/add_op.cc
@@ -49,10 +49,25 @@ The equation is: Out = X + Y
 )DOC");
   }
 };
+
+class AddOpGrad : public framework::OperatorWithKernel {
+protected:
+  void InferShape(
+      const std::vector &inputs,
+      const std::vector &outputs) const override {}
+  std::string DebugString() const override {
+    LOG(INFO) << "AddOpGrad";
+    return "";
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);
+REGISTER_GRADIENT_OP(add_two, paddle::operators::AddOpGrad);
+
 typedef paddle::operators::AddKernel<::paddle::platform::CPUPlace, float>
     AddKernel_CPU_float;
 REGISTER_OP_CPU_KERNEL(add_two, AddKernel_CPU_float);
+// REGISTER_OP_CPU_KERNEL(add_two, AddKernel_CPU_float);
diff --git a/paddle/operators/add_op_test.cc b/paddle/operators/add_op_test.cc
index 53b354fedcacf2176aed8b504daf2046bdf96bb6..7fc1049893e171a17af92da7e813b2463874c9de 100644
--- a/paddle/operators/add_op_test.cc
+++ b/paddle/operators/add_op_test.cc
@@ -16,8 +16,13 @@ limitations under the License. */
 #define private public
 #include
 USE_OP(add_two);
+// USE_OP(add_two_grad);
+
 TEST(AddOp, GetOpProto) {
   auto& protos = paddle::framework::OpRegistry::protos();
   auto it = protos.find("add_two");
   ASSERT_NE(it, protos.end());
-}
\ No newline at end of file
+  auto& grad_creators = paddle::framework::OpRegistry::grad_creators();
+  auto it1 = grad_creators.find("add_two");
+  ASSERT_NE(it1, grad_creators.end());
+}
diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc
index 713b2a5dc83d8dd5a3d944101591d75cb19fe04f..ebf345194c6d15fca20b21e35c42ad6c775255d8 100644
--- a/paddle/operators/mul_op.cc
+++ b/paddle/operators/mul_op.cc
@@ -52,9 +52,22 @@ The equation is: Out = X * Y
   }
 };
 
+class MulOpGrad : public framework::OperatorWithKernel {
+protected:
+  void InferShape(
+      const std::vector &inputs,
+      const std::vector &outputs) const override {}
+  std::string DebugString() const override {
+    LOG(INFO) << "MulGrad";
+    return "";
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 REGISTER_OP(mul, paddle::operators::MulOp, paddle::operators::MulOpMaker);
+REGISTER_GRADIENT_OP(mul, paddle::operators::MulOpGrad);
+
 REGISTER_OP_CPU_KERNEL(
     mul, paddle::operators::MulKernel);
diff --git a/paddle/operators/sigmoid_op.cc b/paddle/operators/sigmoid_op.cc
index 45ae277c538ca90716febaf2f3d92b560149d147..16348db020064447b4445f0075c0591972ca3091 100644
--- a/paddle/operators/sigmoid_op.cc
+++ b/paddle/operators/sigmoid_op.cc
@@ -39,11 +39,24 @@ public:
   }
 };
 
+class SigmoidOpGrad : public framework::OperatorWithKernel {
+protected:
+  void InferShape(
+      const std::vector &inputs,
+      const std::vector &outputs) const override {}
+  std::string DebugString() const override {
+    LOG(INFO) << "SigmoidGrad";
+    return "";
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 REGISTER_OP(sigmoid,
             paddle::operators::SigmoidOp,
             paddle::operators::SigmoidOpMaker);
+REGISTER_GRADIENT_OP(sigmoid, paddle::operators::SigmoidOpGrad);
+
 REGISTER_OP_CPU_KERNEL(
     sigmoid, paddle::operators::SigmoidKernel);
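Note (not part of the patch): a minimal sketch of how the pieces above are meant to fit together for a hypothetical operator "my_op"; MyOp, MyOpMaker and MyOpGrad are placeholder classes assumed to follow the same pattern as AddOp/AddOpGrad. The point of the op_registry.h change is that REGISTER_GRADIENT_OP now asserts on __reg_gradient_op__##__op_type instead of __reg_op__##__op_type, so it can appear next to REGISTER_OP for the same op type without a symbol clash, and a test can then pull the op in with USE_OP and look it up in OpRegistry::grad_creators(), mirroring add_op_test.cc and net_op_test.cc.

// --- my_op.cc (operator translation unit, sketch only) ---
// Register the forward op and its gradient stub side by side; the renamed
// static-assert symbol no longer collides with the one emitted by REGISTER_OP.
REGISTER_OP(my_op, paddle::operators::MyOp, paddle::operators::MyOpMaker);
REGISTER_GRADIENT_OP(my_op, paddle::operators::MyOpGrad);

// --- my_op_test.cc (test translation unit, sketch only) ---
#include <gtest/gtest.h>
#define private public  // same trick as add_op_test.cc, to reach the registry's internal maps
#include "paddle/framework/op_registry.h"

// Reference the op so its registration objects are linked in, then check that
// a gradient creator was recorded for it, as add_op_test.cc now does for add_two.
USE_OP(my_op);

TEST(MyOp, HasGradCreator) {
  auto& grad_creators = paddle::framework::OpRegistry::grad_creators();
  ASSERT_NE(grad_creators.find("my_op"), grad_creators.end());
}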