From 65d2678720a8647f16e284f7890f7e63abfa046d Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Fri, 28 Jul 2017 11:28:33 +0800 Subject: [PATCH] "add simple net test" --- paddle/framework/backward.cc | 2 -- paddle/framework/backward_test.cc | 14 ++++++++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 2d9efdd5114..7e111551d90 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -86,8 +86,6 @@ static std::shared_ptr BackwardImpl( // travesal subnet/op for (auto it = forwardNet.ops_.end(); it != forwardNet.ops_.begin(); --it) { auto fwd = *it; - // for (auto& fwd : forwardNet.ops_) { - // auto bwd = Backward(*fwd, no_grad_names); auto bwd = Backward(*fwd, no_grad_names); net->AddOp(bwd); for (size_t i = 0; i < bwd->outputs_.size(); ++i) { diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 54acc475996..ada7c706829 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -63,10 +63,10 @@ class FcOp : public NetOp { public: void Init() override { AddOp(OpRegistry::CreateOp("mul", {Input("X"), Input("W")}, - {Output("before_act")}, {})); + {Output("mul_out")}, {})); auto b_name = Input("b"); if (b_name != EMPTY_VAR_NAME()) { - AddOp(OpRegistry::CreateOp("rowwise_add", {Output("before_act"), b_name}, + AddOp(OpRegistry::CreateOp("rowwise_add", {Output("mul_out"), b_name}, {Output("before_act")}, {})); } AddOp(OpRegistry::CreateOp("sigmoid", {Output("before_act")}, @@ -82,6 +82,7 @@ class FcOpMaker : public OpProtoAndCheckerMaker { AddInput("X", "x"); AddInput("W", "w"); AddInput("b", "b"); + AddOutput("mul_out", "mul output").SetTemporary(); AddOutput("before_act", "before act").SetTemporary(); AddOutput("Out", ""); AddComment(""); @@ -140,6 +141,7 @@ TEST(Backward, simple_op_grad) { auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); ASSERT_NE(fwd, nullptr); auto gop = f::OpRegistry::CreateGradOp(*fwd); + LOG(INFO) << gop->DebugString(); ASSERT_EQ(1UL, gop->inputs_.size()); ASSERT_EQ("Out" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->inputs_[0]); ASSERT_EQ("rowwise_add_grad", gop->type_); @@ -151,10 +153,18 @@ TEST(Backward, simple_op_grad) { // LOG(INFO) << gop->Output("X" + "@GRAD"); } +TEST(Backward, simple_net_grad) { + auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); + ASSERT_NE(fwd, nullptr); + auto gop = f::Backward(*fwd, {}); + LOG(INFO) << gop->DebugString(); +} + TEST(Backward, net_fc_backward_normal) { std::shared_ptr fwd = f::OpRegistry::CreateOp("fc", {"X", "w", "b"}, {"out"}, {}); ASSERT_NE(fwd, nullptr); + LOG(INFO) << fwd->DebugString(); std::shared_ptr gop = f::Backward(*fwd, {}); ASSERT_TRUE(gop->IsNetOp()); auto net = static_cast(gop.get()); -- GitLab