Unverified commit bf1dc548, authored by weishengying, committed by GitHub

remove all control_vars in IR graph (#46888)

Parent 81b5c2a2
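For context: the `ResolveHazard` calls deleted below are what inserted the synthetic control-dependency variable nodes that `ir::IsControlDepVar` detects in the removed tests. The sketch below is a minimal, hypothetical model of that mechanism — `SimpleVar`, `SimpleOp`, and the `a@control_dep` name are illustrative stand-ins, not Paddle's actual `ir::Node` API — showing why dropping the hazard pass leaves each op with only its real data-flow edges, which is what the rewritten size assertions (`1UL` instead of `2UL`) in `TestMultiBlock` check.

```cpp
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-ins for ir::Node variable/operator nodes (sketch only).
struct SimpleVar {
  std::string name;
  bool is_control_dep;
};

struct SimpleOp {
  std::string type;
  std::vector<SimpleVar*> inputs;
  std::vector<SimpleVar*> outputs;
};

int main() {
  SimpleVar a{"a", false}, b{"b", false}, c{"c", false};

  // sum(X={a}) -> b, then dummy(X={c}) -> a: "a" is read first, then
  // overwritten (the write-after-read case from the deleted test).
  SimpleOp sum{"sum", {&a}, {&b}};
  SimpleOp dummy{"dummy", {&c}, {&a}};

  // Sketch of the old behaviour: ResolveHazard-style logic would insert a
  // synthetic control-dep var so "dummy" runs after "sum" even though the
  // two ops share no data edge.
  const bool hazard_pass_enabled = false;  // this commit removes that pass
  SimpleVar dep{"a@control_dep", true};
  if (hazard_pass_enabled) {
    sum.outputs.push_back(&dep);   // reader emits the ordering token
    dummy.inputs.push_back(&dep);  // later writer waits for it
  }

  // With the pass gone, only real data edges remain: "sum" keeps 1 output
  // and "dummy" keeps 1 input.
  std::cout << "sum outputs: " << sum.outputs.size()
            << ", dummy inputs: " << dummy.inputs.size() << std::endl;
  return 0;
}
```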
@@ -75,7 +75,6 @@ Graph::Graph(const ProgramDesc &program,
     }
   } else {
     auto var_nodes = InitFromProgram(program_, start_op_index, end_op_index);
-    ResolveHazard(var_nodes);
   }
 }
@@ -88,7 +87,6 @@ Graph::Graph(const BlockDesc &block,
              const int64_t end_op_index)
     : main_graph_(main_graph) {
   auto var_nodes = InitFromBlock(block, start_op_index, end_op_index);
-  ResolveHazard(var_nodes);
 }
 
 // TODO(levi): delete this interface after when we can convert all
......
@@ -130,86 +130,6 @@ TEST(GraphTest, Basic) {
   ASSERT_EQ(nodes.size(), 5UL);
 }
 
-TEST(GraphTest, WriteAfterRead) {
-  // void Test() {
-  ProgramDesc prog;
-  auto *op = prog.MutableBlock(0)->AppendOp();
-  op->SetType("sum");
-  op->SetInput("X", {"a"});
-  op->SetOutput("Out", {"b"});
-  op->SetAttr("op_role", 1);
-  op = prog.MutableBlock(0)->AppendOp();
-  op->SetType("dummy");
-  op->SetInput("X", {"c"});
-  op->SetOutput("Out", {"a"});
-  op->SetAttr("op_role", 1);
-  prog.MutableBlock(0)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
-  prog.MutableBlock(0)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
-  prog.MutableBlock(0)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
-
-  std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
-  ir::Node *control_dep1 = nullptr;
-  ir::Node *control_dep2 = nullptr;
-  for (ir::Node *n : g->Nodes()) {
-    if (n->Name() == "sum") {
-      ASSERT_EQ(n->outputs[0]->Name(), "b");
-      ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1]));
-      control_dep1 = n->outputs[1];
-      ASSERT_EQ(n->outputs.size(), 2UL);
-    }
-    if (n->Name() == "dummy") {
-      ASSERT_EQ(n->inputs[0]->Name(), "c");
-      ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
-      control_dep2 = n->inputs[1];
-      ASSERT_EQ(n->inputs.size(), 2UL);
-    }
-  }
-  ASSERT_EQ(control_dep1, control_dep2);
-}
-
-TEST(GraphTest, WriteAfterWrite) {
-  // void Test() {
-  ProgramDesc prog;
-  auto *op = prog.MutableBlock(0)->AppendOp();
-  op->SetType("sum");
-  op->SetInput("X", {"a"});
-  op->SetOutput("Out", {"b"});
-  op->SetAttr("op_role", 1);
-  op = prog.MutableBlock(0)->AppendOp();
-  op->SetType("dummy");
-  op->SetInput("X", {"c"});
-  op->SetOutput("Out", {"b"});
-  op->SetAttr("op_role", 1);
-  prog.MutableBlock(0)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
-  prog.MutableBlock(0)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
-  prog.MutableBlock(0)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
-
-  std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
-  ir::Node *control_dep1 = nullptr;
-  ir::Node *control_dep2 = nullptr;
-  for (ir::Node *n : g->Nodes()) {
-    if (n->Name() == "sum") {
-      ASSERT_EQ(n->outputs[0]->Name(), "b");
-      ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1]));
-      ASSERT_EQ(n->outputs.size(), 2UL);
-      control_dep1 = n->outputs[1];
-    }
-    if (n->Name() == "dummy") {
-      ASSERT_EQ(n->inputs[0]->Name(), "c");
-      ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
-      control_dep2 = n->inputs[1];
-      ASSERT_EQ(n->inputs.size(), 2UL);
-    }
-  }
-  ASSERT_NE(control_dep1, nullptr);
-  ASSERT_NE(control_dep2, nullptr);
-  ASSERT_EQ(control_dep1, control_dep2);
-}
-
 TEST(GraphTest, TestException) {
   ProgramDesc prog;
   std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
@@ -350,12 +270,13 @@ TEST(GraphTest, TestMultiBlock) {
   op = prog.MutableBlock(1)->AppendOp();
   op->SetType("dummy");
   op->SetInput("X", {"c"});
-  op->SetOutput("Out", {"a"});
+  op->SetOutput("Out", {"d"});
   op->SetAttr("op_role", 1);
   prog.MutableBlock(1)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
   prog.MutableBlock(1)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
   prog.MutableBlock(1)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
+  prog.MutableBlock(1)->Var("d")->SetType(proto::VarType::LOD_TENSOR);
 
   // Set contents in block_2.
   op = prog.MutableBlock(2)->AppendOp();
@@ -367,12 +288,13 @@ TEST(GraphTest, TestMultiBlock) {
   op = prog.MutableBlock(2)->AppendOp();
   op->SetType("dummy");
   op->SetInput("X", {"c"});
-  op->SetOutput("Out", {"b"});
+  op->SetOutput("Out", {"d"});
   op->SetAttr("op_role", 1);
   prog.MutableBlock(2)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
   prog.MutableBlock(2)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
   prog.MutableBlock(2)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
+  prog.MutableBlock(1)->Var("d")->SetType(proto::VarType::LOD_TENSOR);
 
   // Step2: Convert program into graph, 3 blocks corresponding 3 sub_graphs.
   std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
@@ -399,45 +321,29 @@ TEST(GraphTest, TestMultiBlock) {
 
   // Check contents in sub_graph_1.
   const ir::Graph *g1 = g->GetSubGraph(1);
-  ir::Node *control_dep1 = nullptr;
-  ir::Node *control_dep2 = nullptr;
   for (ir::Node *n : g1->Nodes()) {
     if (n->Name() == "sum") {
       ASSERT_EQ(n->outputs[0]->Name(), "b");
-      ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1]));
-      control_dep1 = n->outputs[1];
-      ASSERT_EQ(n->outputs.size(), 2UL);
+      ASSERT_EQ(n->outputs.size(), 1UL);
     }
     if (n->Name() == "dummy") {
       ASSERT_EQ(n->inputs[0]->Name(), "c");
-      ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
-      control_dep2 = n->inputs[1];
-      ASSERT_EQ(n->inputs.size(), 2UL);
+      ASSERT_EQ(n->inputs.size(), 1UL);
     }
   }
-  ASSERT_EQ(control_dep1, control_dep2);
 
   // Check contents in sub_graph_2.
   const ir::Graph *g2 = g->GetSubGraph(2);
-  control_dep1 = nullptr;
-  control_dep2 = nullptr;
   for (ir::Node *n : g2->Nodes()) {
     if (n->Name() == "sum") {
       ASSERT_EQ(n->outputs[0]->Name(), "b");
-      ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1]));
-      ASSERT_EQ(n->outputs.size(), 2UL);
-      control_dep1 = n->outputs[1];
+      ASSERT_EQ(n->outputs.size(), 1UL);
     }
     if (n->Name() == "dummy") {
       ASSERT_EQ(n->inputs[0]->Name(), "c");
-      ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
-      control_dep2 = n->inputs[1];
-      ASSERT_EQ(n->inputs.size(), 2UL);
+      ASSERT_EQ(n->inputs.size(), 1UL);
     }
   }
-  ASSERT_NE(control_dep1, nullptr);
-  ASSERT_NE(control_dep2, nullptr);
-  ASSERT_EQ(control_dep1, control_dep2);
 
   // Step3: Clone graph.
   std::shared_ptr<ir::Graph> clone_g = g->Clone();
......
@@ -331,8 +331,6 @@ void BatchMergePass::ApplyImpl(ir::Graph* graph) const {
       copy_node(node);
     }
   }
-
-  result.ResolveHazard(created);
 }
 
 }  // namespace ir
......