From 14424f314c4d2018b49ad242c82738a21d2fe9e3 Mon Sep 17 00:00:00 2001
From: dongzhihong
Date: Thu, 20 Jul 2017 10:03:53 +0800
Subject: [PATCH] "use built-in operator"

---
 paddle/framework/CMakeLists.txt |  2 +-
 paddle/framework/net_op_test.cc | 19 +++++++++++++------
 paddle/framework/op_registry.h  |  6 +++---
 paddle/operators/add_op.cc      | 15 +++++++++++++++
 paddle/operators/add_op_test.cc |  7 ++++++-
 paddle/operators/mul_op.cc      | 13 +++++++++++++
 paddle/operators/sigmoid_op.cc  | 13 +++++++++++++
 7 files changed, 64 insertions(+), 11 deletions(-)

diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt
index d018ee50c0..b56107daf1 100644
--- a/paddle/framework/CMakeLists.txt
+++ b/paddle/framework/CMakeLists.txt
@@ -29,4 +29,4 @@ add_dependencies(framework_py_proto framework_py_proto_init)
 proto_library(net_proto SRCS net_proto.proto DEPS op_proto)
 # cc_library(net SRCS net.cc DEPS operator net_proto op_registry fc_op)
 cc_library(net SRCS net.cc DEPS operator net_proto op_registry)
-cc_test(net_op_test SRCS net_op_test.cc DEPS net my_fc_op)
+cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op)
diff --git a/paddle/framework/net_op_test.cc b/paddle/framework/net_op_test.cc
index 18151c56d9..2e74235261 100644
--- a/paddle/framework/net_op_test.cc
+++ b/paddle/framework/net_op_test.cc
@@ -2,7 +2,10 @@
 #include <paddle/framework/net.h>
 #include <paddle/framework/op_registry.h>
 #include <paddle/framework/operator.h>
-#include "paddle/framework/fully_connected_op.h"
+
+USE_OP(add_two);
+USE_OP(mul);
+USE_OP(sigmoid);
 
 namespace paddle {
 namespace framework {
@@ -65,14 +68,18 @@ TEST(OpKernel, all) {
   ASSERT_THROW(net->AddOp(op2), EnforceNotMet);
 }
 
-
 TEST(AddBackwardOp, TestGradOp) {
   auto net = std::make_shared<PlainNet>();
   ASSERT_NE(net, nullptr);
-  auto op1 = std::make_shared<FcOp>();
-  op1->inputs_ = {"x", "w1", "b1"};
-  op1->outputs_ = {"y"};
-  net->AddOp(op1);
+  net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {}));
+  net->AddOp(
+      framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {}));
+  net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""}, {}));
+  // net->AddOp(framework::OpRegistry::CreateOp("fc"), {
+  //     Input("X"), Input("W"), Input("b")},
+  //     {Output("Y")},
+  //     {}
+  // );
   auto grad_ops = AddBackwardOp(net);
   for (auto& op : grad_ops->ops_) {
     op->DebugString();
diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h
index 92354f4ffd..07c3399462 100644
--- a/paddle/framework/op_registry.h
+++ b/paddle/framework/op_registry.h
@@ -470,11 +470,11 @@ class GradOpRegisterHelper {
  */
 #define REGISTER_GRADIENT_OP(__op_type, __op_class)              \
   STATIC_ASSERT_GLOBAL_NAMESPACE(                                \
-      __reg_op__##__op_type,                                     \
+      __reg_gradient_op__##__op_type,                            \
       "REGISTER_GRADIENT_OP must be in global namespace");       \
   static ::paddle::framework::GradOpRegisterHelper<__op_class>   \
-      __op_register_##__op_type##__(#__op_type);                 \
-  int __op_register_##__op_type##_handle__() { return 0; }
+      __op_gradient_register_##__op_type##__(#__op_type);        \
+  int __op_gradient_register_##__op_type##_handle__() { return 0; }
 
 /**
  * Macro to Register OperatorKernel.
diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc
index 41d044cdb7..f59a027407 100644
--- a/paddle/operators/add_op.cc
+++ b/paddle/operators/add_op.cc
@@ -49,10 +49,25 @@ The equation is: Out = X + Y
 )DOC");
   }
 };
+
+class AddOpGrad : public framework::OperatorWithKernel {
+protected:
+  void InferShape(
+      const std::vector<const framework::Tensor *> &inputs,
+      const std::vector<framework::Tensor *> &outputs) const override {}
+  std::string DebugString() const override {
+    LOG(INFO) << "AddOpGrad";
+    return "";
+  }
+};
+
 } // namespace operators
 } // namespace paddle
 
 REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);
+REGISTER_GRADIENT_OP(add_two, paddle::operators::AddOpGrad);
+
 typedef paddle::operators::AddKernel<::paddle::platform::CPUPlace, float>
     AddKernel_CPU_float;
 REGISTER_OP_CPU_KERNEL(add_two, AddKernel_CPU_float);
+// REGISTER_OP_CPU_KERNEL(add_two, AddKernel_CPU_float);
diff --git a/paddle/operators/add_op_test.cc b/paddle/operators/add_op_test.cc
index 53b354fedc..7fc1049893 100644
--- a/paddle/operators/add_op_test.cc
+++ b/paddle/operators/add_op_test.cc
@@ -16,8 +16,13 @@ limitations under the License. */
 #define private public
 #include <paddle/framework/op_registry.h>
 USE_OP(add_two);
+// USE_OP(add_two_grad);
+
 TEST(AddOp, GetOpProto) {
   auto& protos = paddle::framework::OpRegistry::protos();
   auto it = protos.find("add_two");
   ASSERT_NE(it, protos.end());
-}
\ No newline at end of file
+  auto& grad_creators = paddle::framework::OpRegistry::grad_creators();
+  auto it1 = grad_creators.find("add_two");
+  ASSERT_NE(it1, grad_creators.end());
+}
diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc
index 713b2a5dc8..ebf345194c 100644
--- a/paddle/operators/mul_op.cc
+++ b/paddle/operators/mul_op.cc
@@ -52,9 +52,22 @@ The equation is: Out = X * Y
   }
 };
 
+class MulOpGrad : public framework::OperatorWithKernel {
+protected:
+  void InferShape(
+      const std::vector<const framework::Tensor *> &inputs,
+      const std::vector<framework::Tensor *> &outputs) const override {}
+  std::string DebugString() const override {
+    LOG(INFO) << "MulGrad";
+    return "";
+  }
+};
+
 } // namespace operators
 } // namespace paddle
 
 REGISTER_OP(mul, paddle::operators::MulOp, paddle::operators::MulOpMaker);
+REGISTER_GRADIENT_OP(mul, paddle::operators::MulOpGrad);
+
 REGISTER_OP_CPU_KERNEL(
     mul, paddle::operators::MulKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/sigmoid_op.cc b/paddle/operators/sigmoid_op.cc
index 45ae277c53..16348db020 100644
--- a/paddle/operators/sigmoid_op.cc
+++ b/paddle/operators/sigmoid_op.cc
@@ -39,11 +39,24 @@ public:
   }
 };
 
+class SigmoidOpGrad : public framework::OperatorWithKernel {
+protected:
+  void InferShape(
+      const std::vector<const framework::Tensor *> &inputs,
+      const std::vector<framework::Tensor *> &outputs) const override {}
+  std::string DebugString() const override {
+    LOG(INFO) << "SigmoidGrad";
+    return "";
+  }
+};
+
 } // namespace operators
 } // namespace paddle
 
 REGISTER_OP(sigmoid, paddle::operators::SigmoidOp,
             paddle::operators::SigmoidOpMaker);
+REGISTER_GRADIENT_OP(sigmoid, paddle::operators::SigmoidOpGrad);
+
 REGISTER_OP_CPU_KERNEL(
     sigmoid, paddle::operators::SigmoidKernel<paddle::platform::CPUPlace, float>);
-- 
GitLab
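
Note (reviewer sketch, not part of the patch): the diffs above register one gradient operator per forward op via REGISTER_GRADIENT_OP and expose the result through OpRegistry::grad_creators(). A minimal check of that wiring for another op, mirroring the lookup pattern in add_op_test.cc, could look as follows; the test name and the choice of the mul op are illustrative assumptions, not code from the patch.

// grad_creators() is private in OpRegistry, so reuse the same
// `#define private public` trick that add_op_test.cc applies above.
#define private public
#include <paddle/framework/op_registry.h>

// Pull in mul_op.cc's static registration so that
// REGISTER_GRADIENT_OP(mul, MulOpGrad) has run before the test body.
USE_OP(mul);

TEST(MulOp, GradCreatorRegistered) {
  // REGISTER_GRADIENT_OP keys the gradient creator by the forward op's
  // type string, so "mul" should be present once mul_op is linked in.
  auto& grad_creators = paddle::framework::OpRegistry::grad_creators();
  ASSERT_NE(grad_creators.find("mul"), grad_creators.end());
}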