From 81a352aff7bc6588b652023b290986ba02301df9 Mon Sep 17 00:00:00 2001
From: dongzhihong
Date: Thu, 20 Jul 2017 10:31:44 +0800
Subject: [PATCH] "test fc without gradient"

---
 paddle/framework/CMakeLists.txt |  2 +-
 paddle/framework/net.cc         |  2 --
 paddle/framework/net_op_test.cc | 16 +++++++++++-----
 paddle/framework/op_registry.h  |  8 +++++++-
 paddle/operators/softmax_op.cc  | 13 +++++++++++++
 5 files changed, 32 insertions(+), 9 deletions(-)

diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt
index b56107daf1..5eec31197f 100644
--- a/paddle/framework/CMakeLists.txt
+++ b/paddle/framework/CMakeLists.txt
@@ -29,4 +29,4 @@ add_dependencies(framework_py_proto framework_py_proto_init)
 proto_library(net_proto SRCS net_proto.proto DEPS op_proto)
 # cc_library(net SRCS net.cc DEPS operator net_proto op_registry fc_op)
 cc_library(net SRCS net.cc DEPS operator net_proto op_registry)
-cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op)
+cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op)
diff --git a/paddle/framework/net.cc b/paddle/framework/net.cc
index bb02dcbcee..8902e2bcf1 100644
--- a/paddle/framework/net.cc
+++ b/paddle/framework/net.cc
@@ -22,8 +22,6 @@ namespace framework {
 
 std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps) {
   auto grad_ops = std::make_shared<PlainNet>();
-  // std::shared_ptr<PlainNet> grad_ops;
-  // grad_ops.reset(new PlainNet);
   for (auto& op : ForwardOps->ops_) {
     auto op_grad = OpRegistry::CreateGradOp(op);
     grad_ops->AddOp(op_grad);
diff --git a/paddle/framework/net_op_test.cc b/paddle/framework/net_op_test.cc
index 2e74235261..2f24816bf8 100644
--- a/paddle/framework/net_op_test.cc
+++ b/paddle/framework/net_op_test.cc
@@ -6,6 +6,7 @@
 USE_OP(add_two);
 USE_OP(mul);
 USE_OP(sigmoid);
+USE_OP(softmax);
 
 namespace paddle {
 namespace framework {
@@ -75,16 +76,21 @@ TEST(AddBackwardOp, TestGradOp) {
   net->AddOp(
       framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {}));
   net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""}, {}));
-  // net->AddOp(framework::OpRegistry::CreateOp("fc"), {
-  //     Input("X"), Input("W"), Input("b")},
-  //     {Output("Y")},
-  //     {}
-  //     );
   auto grad_ops = AddBackwardOp(net);
   for (auto& op : grad_ops->ops_) {
     op->DebugString();
   }
 }
 
+// TODO(zhihong): add fc grad without registering.
+// TEST(AddBackwardOp, TestNoGradOp) {
+//   auto net = std::make_shared<PlainNet>();
+//   ASSERT_NE(net, nullptr);
+//   net->AddOp(framework::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Y"},
+//   {})); auto grad_ops = AddBackwardOp(net); for (auto& op : grad_ops->ops_) {
+//     op->DebugString();
+//   }
+// }
+
 } // namespace framework
 } // namespace paddle
diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h
index 07c3399462..0aa1eca837 100644
--- a/paddle/framework/op_registry.h
+++ b/paddle/framework/op_registry.h
@@ -286,7 +286,13 @@ class OpRegistry {
   }
 
   static OperatorPtr CreateGradOp(OperatorPtr op) {
-    OperatorPtr grad_op(grad_creators().at(op->type_)());
+    auto it = grad_creators().find(op->type_);
+    if (it == grad_creators().end()) {
+      LOG(INFO) << op->type_ << " does not have a gradient op";
+      return nullptr;
+    }
+    // OperatorPtr grad_op(grad_creators().at(op->type_)());
+    OperatorPtr grad_op(it->second());
     grad_op->type_ = op->type_;
 
     AssembleGradInOut(op, grad_op);
diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc
index 4ca7be359e..146326d283 100644
--- a/paddle/operators/softmax_op.cc
+++ b/paddle/operators/softmax_op.cc
@@ -40,10 +40,23 @@ public:
   }
 };
 
+class SoftmaxOpGrad : public framework::OperatorWithKernel {
+protected:
+  void InferShape(
+      const std::vector<const framework::Tensor *> &inputs,
+      const std::vector<framework::Tensor *> &outputs) const override {}
+  std::string DebugString() const override {
+    LOG(INFO) << "SoftmaxOpGrad";
+    return "";
+  }
+};
+
 } // namespace operators
 } // namespace paddle
 
 namespace ops = paddle::operators;
 
 REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker);
+REGISTER_GRADIENT_OP(softmax, paddle::operators::SoftmaxOpGrad);
+
 REGISTER_OP_CPU_KERNEL(softmax, ops::SoftmaxKernel);
--
GitLab
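Note: the sketch below is illustrative only and not part of the patch. It is the commented-out TestNoGradOp case written out, using the same PlainNet, AddBackwardOp, and OpRegistry helpers that net_op_test.cc already uses. Since "fc" has no registered gradient op, CreateGradOp now returns nullptr for it, so the loop here adds a null check that the current code does not yet perform (which is presumably why the test is left disabled under the TODO).

// Illustrative sketch, assuming the helpers from net_op_test.cc; the nullptr
// guard is an addition on top of the commented-out test in the patch.
TEST(AddBackwardOp, TestNoGradOp) {
  auto net = std::make_shared<PlainNet>();
  ASSERT_NE(net, nullptr);
  net->AddOp(
      framework::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Y"}, {}));
  auto grad_ops = AddBackwardOp(net);
  for (auto& op : grad_ops->ops_) {
    // "fc" has no gradient op registered, so CreateGradOp returned nullptr
    // for it and the corresponding entry must be skipped.
    if (op != nullptr) {
      op->DebugString();
    }
  }
}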