diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc
index 9886679d309b74a23a3f4e85e70a7ceec32c3ca8..f3d2c8d54b4ea43e214ee1536d0d2230e634f89e 100644
--- a/paddle/framework/backward_test.cc
+++ b/paddle/framework/backward_test.cc
@@ -165,33 +165,12 @@ TEST(Backward, simple_op_not_need_grad) {
   auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
   ASSERT_NE(fwd, nullptr);
   auto gop = f::Backward(*fwd, {"X"});
-  LOG(INFO) << "full " << gop->DebugString();
-  ASSERT_NE(std::find(gop->outputs_.begin(), gop->outputs_.end(),
-                      std::string("X") + f::OperatorBase::GRAD_VAR_SUFFIX()),
+  ASSERT_EQ(std::find(gop->outputs_.begin(), gop->outputs_.end(),
+                      "X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
             gop->outputs_.end());
 
   auto no_input_gop = f::Backward(*fwd, {"X", "b"});
-  LOG(INFO) << "no input gop " << gop->DebugString();
   ASSERT_NE(no_input_gop, nullptr);
-
-  typedef std::vector<std::string> Vec;
-  auto vector_equal = [](const Vec &l, const Vec &r) {
-    return l.size() == r.size();
-    for (size_t i = 0; i < l.size(); ++i) {
-      if (l[i] != r[i]) return false;
-    }
-    return true;
-  };
-  ASSERT_EQ(vector_equal(std::vector<std::string>{}, no_input_gop->outputs_),
-            true);
-  ASSERT_EQ(
-      vector_equal(
-          std::vector<std::string>{"Out" + f::OperatorBase::GRAD_VAR_SUFFIX()},
-          no_input_gop->inputs_),
-      true);
-  // auto no_output_gop = f::Backward(*fwd, {"Out"});
-  // ASSERT_EQ(std::vector<std::string>{"X" +
-  // f::OperatorBase::GRAD_VAR_SUFFIX(), "b"})
 }
 
 TEST(Backward, net_fc_backward_normal) {
@@ -251,6 +230,8 @@ TEST(Backward, net_input_of_network_not_need_grad) {
                     bwd_net->outputs_.begin(), bwd_net->outputs_.end());
   all_output.erase(f::OperatorBase::EMPTY_VAR_NAME());
 
+  LOG(INFO) << bwd_net->DebugString();
+  LOG(INFO) << bwd_net->ops_.size();
   for (auto &out : {"W1", "b1", "hidden0", "W2", "b2"}) {
     ASSERT_NE(all_output.find(out + f::OperatorBase::GRAD_VAR_SUFFIX()),
               all_output.end());
@@ -264,6 +245,7 @@ TEST(Backward, net_input_of_network_not_need_grad) {
   ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
   auto first_fc_grad = static_cast<f::NetOp *>(bwd_net->ops_[1].get());
   ASSERT_EQ(3UL, first_fc_grad->ops_.size());
+  LOG(INFO) << first_fc_grad->DebugString();
   ASSERT_EQ(f::OperatorBase::EMPTY_VAR_NAME(),
             first_fc_grad[2].Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()));
 }
@@ -333,7 +315,7 @@ TEST(Backward, op_part_of_output_are_not_need) {
 TEST(Backward, op_part_of_input_are_not_need) {
   auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {});
   auto backward = f::Backward(*fwd, {"a"});
-  ASSERT_TRUE(backward->IsNetOp());
+  ASSERT_FALSE(backward->IsNetOp());
   auto net = static_cast<f::NetOp *>(backward.get());
   ASSERT_EQ(net->ops_.size(), 1UL);
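
Reviewer note on the first hunk: the deleted `vector_equal` lambda returned `l.size() == r.size()` on its first line, so the element-wise loop after it was unreachable and two equal-sized vectors with different contents compared as equal. For reference, a corrected standalone sketch (the `VectorEqual` name is illustrative, not part of this patch; gmock's `testing::ContainerEq` covers the same check):

    #include <string>
    #include <vector>

    // Returns true iff l and r hold the same elements in the same order.
    static bool VectorEqual(const std::vector<std::string> &l,
                            const std::vector<std::string> &r) {
      if (l.size() != r.size()) return false;  // compare sizes first
      for (size_t i = 0; i < l.size(); ++i) {
        if (l[i] != r[i]) return false;  // then compare element-wise
      }
      return true;
    }

Deleting the helper along with the assertions that used it also sidesteps the issue entirely: `ASSERT_EQ` on two `std::vector<std::string>` values already compares sizes and elements via `operator==`.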
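A pre-existing context line in the third hunk may deserve a follow-up: `first_fc_grad[2].Output(...)` indexes the `first_fc_grad` pointer itself rather than a sub-operator, even though the preceding assertion checks `first_fc_grad->ops_.size()`. If the intent is to inspect the third sub-op of the grad net, the expected form would presumably be (a sketch, not part of this patch, assuming `ops_` holds operator pointers as the surrounding `->IsNetOp()` and `->ops_.size()` calls suggest):

    // Hypothetical correction: query the third sub-op of the grad net,
    // not the third element past the NetOp pointer.
    ASSERT_EQ(f::OperatorBase::EMPTY_VAR_NAME(),
              first_fc_grad->ops_[2]->Output(
                  "X" + f::OperatorBase::GRAD_VAR_SUFFIX()));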