diff --git a/paddle/fluid/framework/ir/graph_test.cc b/paddle/fluid/framework/ir/graph_test.cc index feb4e8e76b600ab3544d8903674972835527ad41..73ef55756c330bdbc3be89c436967b2a88625a43 100644 --- a/paddle/fluid/framework/ir/graph_test.cc +++ b/paddle/fluid/framework/ir/graph_test.cc @@ -95,19 +95,19 @@ TEST(GraphTest, Basic) { std::unique_ptr g(new ir::Graph(prog)); std::vector nodes(g->Nodes().begin(), g->Nodes().end()); - ASSERT_EQ(nodes[0]->Name(), "sum"); - ASSERT_EQ(nodes[0]->inputs[0]->Name(), "test_a"); - ASSERT_EQ(nodes[0]->inputs[1]->Name(), "test_b"); - ASSERT_EQ(nodes[0]->inputs[2]->Name(), "test_c"); - ASSERT_EQ(nodes[0]->outputs[0]->Name(), "test_out"); - ASSERT_EQ(nodes[1]->Name(), "test_a"); - ASSERT_EQ(nodes[1]->outputs[0]->Name(), "sum"); - ASSERT_EQ(nodes[2]->Name(), "test_b"); - ASSERT_EQ(nodes[2]->outputs[0]->Name(), "sum"); - ASSERT_EQ(nodes[3]->Name(), "test_c"); - ASSERT_EQ(nodes[3]->outputs[0]->Name(), "sum"); - ASSERT_EQ(nodes[4]->Name(), "test_out"); - ASSERT_EQ(nodes[4]->inputs[0]->Name(), "sum"); + for (ir::Node *n : nodes) { + if (n->Name() == "sum") { + ASSERT_EQ(n->inputs.size(), 3); + ASSERT_EQ(n->outputs.size(), 1); + } else if (n->Name() == "test_a" || n->Name() == "test_b" || + n->Name() == "test_c") { + ASSERT_EQ(n->inputs.size(), 0); + ASSERT_EQ(n->outputs.size(), 1); + } else if (n->Name() == "test_out") { + ASSERT_EQ(n->inputs.size(), 1); + ASSERT_EQ(n->outputs.size(), 0); + } + } ASSERT_EQ(nodes.size(), 5); } } // namespace framework