From 74cd9a7542027a89b0751c2cb5c45bb8f413c52b Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Fri, 28 Jul 2017 13:57:31 +0800 Subject: [PATCH] "fix unittest" --- paddle/framework/backward.cc | 2 +- paddle/framework/backward_test.cc | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 1a24d266db1..b6c46302b1f 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -79,11 +79,11 @@ static std::shared_ptr BackwardImpl( std::unordered_map /*op offset*/> dup_output_ops; - size_t local_op_id = 0; // Because it is a net op, it can static_cast. auto& forwardNet = static_cast(forwardOp); // travesal subnet/op + size_t local_op_id = 0; for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend(); ++it) { auto fwd = *it; diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 7472a970b91..cb1d402526e 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -149,7 +149,6 @@ TEST(Backward, simple_op_grad) { auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); ASSERT_NE(fwd, nullptr); auto gop = f::OpRegistry::CreateGradOp(*fwd); - LOG(INFO) << gop->DebugString(); ASSERT_EQ(1UL, gop->inputs_.size()); ASSERT_EQ("Out" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->inputs_[0]); ASSERT_EQ("rowwise_add_grad", gop->type_); @@ -161,18 +160,19 @@ TEST(Backward, simple_op_grad) { // LOG(INFO) << gop->Output("X" + "@GRAD"); } -TEST(Backward, simple_net_grad) { - auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); +TEST(Backward, simple_op_not_need_grad) { + auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"x", "b"}, {"out"}, {}); ASSERT_NE(fwd, nullptr); - auto gop = f::Backward(*fwd, {}); + auto gop = f::Backward(*fwd, {"x"}); LOG(INFO) << gop->DebugString(); + ASSERT_NE(gop->outputs_.find("x" + f::OperatorBase::GRAD_VAR_SUFFIX()), + gop->outputs_.end()); } TEST(Backward, net_fc_backward_normal) { std::shared_ptr fwd = f::OpRegistry::CreateOp( "fc", {"X", "w", "b"}, {"mul_result", "add_result", "out"}, {}); ASSERT_NE(fwd, nullptr); - LOG(INFO) << fwd->DebugString(); std::shared_ptr gop = f::Backward(*fwd, {}); ASSERT_TRUE(gop->IsNetOp()); auto net = static_cast(gop.get()); -- GitLab