From 658588a6755b8b036d87d6a89928a36dadfb7f00 Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Fri, 28 Jul 2017 14:28:09 +0800 Subject: [PATCH] "format test case" --- paddle/framework/backward_test.cc | 52 +++++++++++++++++++------------ 1 file changed, 32 insertions(+), 20 deletions(-) diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 69faee9fb7..9886679d30 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -167,15 +167,28 @@ TEST(Backward, simple_op_not_need_grad) { auto gop = f::Backward(*fwd, {"X"}); LOG(INFO) << "full " << gop->DebugString(); ASSERT_NE(std::find(gop->outputs_.begin(), gop->outputs_.end(), - "X" + f::OperatorBase::GRAD_VAR_SUFFIX()), + std::string("X") + f::OperatorBase::GRAD_VAR_SUFFIX()), gop->outputs_.end()); + auto no_input_gop = f::Backward(*fwd, {"X", "b"}); - LOG(INFO) << "no input gop " << no_input_gop->DebugString(); + LOG(INFO) << "no input gop " << gop->DebugString(); ASSERT_NE(no_input_gop, nullptr); - ASSERT_EQ(std::vector{}, no_input_gop->outputs_); + + typedef std::vector Vec; + auto vector_equal = [](const Vec &l, const Vec &r) { + return l.size() == r.size(); + for (size_t i = 0; i < l.size(); ++i) { + if (l[i] != r[i]) return false; + } + return true; + }; + ASSERT_EQ(vector_equal(std::vector{}, no_input_gop->outputs_), + true); ASSERT_EQ( - std::vector{"Out" + f::OperatorBase::GRAD_VAR_SUFFIX()}, - no_input_gop->inputs_); + vector_equal( + std::vector{"Out" + f::OperatorBase::GRAD_VAR_SUFFIX()}, + no_input_gop->inputs_), + true); // auto no_output_gop = f::Backward(*fwd, {"Out"}); // ASSERT_EQ(std::vector{"X" + // f::OperatorBase::GRAD_VAR_SUFFIX(), "b"}) @@ -251,9 +264,8 @@ TEST(Backward, net_input_of_network_not_need_grad) { ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp()); auto first_fc_grad = static_cast(bwd_net->ops_[1].get()); ASSERT_EQ(3UL, first_fc_grad->ops_.size()); - ASSERT_EQ( - f::OperatorBase::EMPTY_VAR_NAME(), - first_fc_grad->ops_[2]->Output("A" + f::OperatorBase::GRAD_VAR_SUFFIX())); + ASSERT_EQ(f::OperatorBase::EMPTY_VAR_NAME(), + first_fc_grad[2].Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX())); } TEST(Backward, net_shared_weight) { @@ -266,13 +278,14 @@ TEST(Backward, net_shared_weight) { ASSERT_TRUE(bwd->IsNetOp()); auto bwd_net = static_cast(bwd.get()); ASSERT_EQ(3UL, bwd_net->ops_.size()); + LOG(INFO) << bwd_net->DebugString(); ASSERT_EQ("add_grad", bwd_net->ops_[2]->type_); } TEST(Backward, op_register_grad_not_for_network) { - auto fwd = f::OpRegistry::CreateOp( - "fc", {"X", "W", "b"}, {"mul_result", "add_result", "Out"}, - {{"temporary_index", std::vector{1}}}); + auto fwd = + f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"}, + {{"temporary_index", std::vector{1}}}); ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet); } @@ -320,9 +333,11 @@ TEST(Backward, op_part_of_output_are_not_need) { TEST(Backward, op_part_of_input_are_not_need) { auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {}); auto backward = f::Backward(*fwd, {"a"}); - ASSERT_TRUE(!backward->IsNetOp()); + ASSERT_TRUE(backward->IsNetOp()); + auto net = static_cast(backward.get()); + ASSERT_EQ(net->ops_.size(), 1UL); - auto &grad_mul = *backward; + auto &grad_mul = *net->ops_[0]; ASSERT_EQ(grad_mul.type_, "mul_grad"); ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL); ASSERT_EQ(grad_mul.outputs_.size(), 2UL); @@ -339,13 +354,10 @@ TEST(Backward, op_part_of_input_are_not_need) { TEST(Backward, linear_net_intermediate_variable_has_no_grad) { f::NetOp net; - net.AddOp(f::OpRegistry::CreateOp("fc", {"x1", "w1", "b1"}, - {"mul_out1", "add_out1", "out1"}, {})); - net.AddOp(f::OpRegistry::CreateOp("fc", {"out1", "w2", "b2"}, - {"mul_out2", "tmp_out2", "out2"}, {})); - net.AddOp(f::OpRegistry::CreateOp("fc", {"out2", "w3", "b3"}, - {"mul_out3", "tmp_out3", "out3"}, {})); - net.CompleteAddOp(); + net.AddOp(f::OpRegistry::CreateOp("fc", {"x1", "w1", "b1"}, {"out1"}, {})); + net.AddOp(f::OpRegistry::CreateOp("fc", {"out1", "w2", "b2"}, {"out2"}, {})); + net.AddOp(f::OpRegistry::CreateOp("fc", {"out2", "w3", "b3"}, {"out3"}, {})); + net.CompleteAddOp(false); auto backward = f::Backward(net, {"out2"}); ASSERT_TRUE(backward->IsNetOp()); auto bwd_net = static_cast(backward.get()); -- GitLab