diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index 1a24d266db14bf142e3227a2ace0daa5126224cb..b6c46302b1f39b79f60e6f6accb4a3e6becb001f 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -79,11 +79,11 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
     std::unordered_map<std::string /*var name*/,
                        std::vector<size_t> /*op offset*/>
         dup_output_ops;
-    size_t local_op_id = 0;
     // Because it is a net op, it can static_cast.
     auto& forwardNet = static_cast<const NetOp&>(forwardOp);
 
     // travesal subnet/op
+    size_t local_op_id = 0;
     for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
          ++it) {
       auto fwd = *it;
diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc
index 7472a970b916b498491e1f5809d8d84b896cee01..cb1d402526e0b0008263c8f19e4d175c6e1502f6 100644
--- a/paddle/framework/backward_test.cc
+++ b/paddle/framework/backward_test.cc
@@ -149,7 +149,6 @@ TEST(Backward, simple_op_grad) {
   auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
   ASSERT_NE(fwd, nullptr);
   auto gop = f::OpRegistry::CreateGradOp(*fwd);
-  LOG(INFO) << gop->DebugString();
   ASSERT_EQ(1UL, gop->inputs_.size());
   ASSERT_EQ("Out" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->inputs_[0]);
   ASSERT_EQ("rowwise_add_grad", gop->type_);
@@ -161,18 +160,19 @@ TEST(Backward, simple_op_grad) {
   // LOG(INFO) << gop->Output("X" + "@GRAD");
 }
 
-TEST(Backward, simple_net_grad) {
-  auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
+TEST(Backward, simple_op_not_need_grad) {
+  auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"x", "b"}, {"out"}, {});
   ASSERT_NE(fwd, nullptr);
-  auto gop = f::Backward(*fwd, {});
+  auto gop = f::Backward(*fwd, {"x"});
   LOG(INFO) << gop->DebugString();
+  ASSERT_NE(gop->outputs_.find("x" + f::OperatorBase::GRAD_VAR_SUFFIX()),
+            gop->outputs_.end());
 }
 
 TEST(Backward, net_fc_backward_normal) {
   std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
       "fc", {"X", "w", "b"}, {"mul_result", "add_result", "out"}, {});
   ASSERT_NE(fwd, nullptr);
-  LOG(INFO) << fwd->DebugString();
   std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
   ASSERT_TRUE(gop->IsNetOp());
   auto net = static_cast<f::NetOp*>(gop.get());