未验证 提交 bf1dc548 编写于 作者: W weishengying 提交者: GitHub

remove all control_vars in IR graph (#46888)

上级 81b5c2a2
......@@ -75,7 +75,6 @@ Graph::Graph(const ProgramDesc &program,
}
} else {
auto var_nodes = InitFromProgram(program_, start_op_index, end_op_index);
ResolveHazard(var_nodes);
}
}
......@@ -88,7 +87,6 @@ Graph::Graph(const BlockDesc &block,
const int64_t end_op_index)
: main_graph_(main_graph) {
auto var_nodes = InitFromBlock(block, start_op_index, end_op_index);
ResolveHazard(var_nodes);
}
// TODO(levi): delete this interface after when we can convert all
......
......@@ -130,86 +130,6 @@ TEST(GraphTest, Basic) {
ASSERT_EQ(nodes.size(), 5UL);
}
TEST(GraphTest, WriteAfterRead) {
// void Test() {
ProgramDesc prog;
auto *op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum");
op->SetInput("X", {"a"});
op->SetOutput("Out", {"b"});
op->SetAttr("op_role", 1);
op = prog.MutableBlock(0)->AppendOp();
op->SetType("dummy");
op->SetInput("X", {"c"});
op->SetOutput("Out", {"a"});
op->SetAttr("op_role", 1);
prog.MutableBlock(0)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
ir::Node *control_dep1 = nullptr;
ir::Node *control_dep2 = nullptr;
for (ir::Node *n : g->Nodes()) {
if (n->Name() == "sum") {
ASSERT_EQ(n->outputs[0]->Name(), "b");
ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1]));
control_dep1 = n->outputs[1];
ASSERT_EQ(n->outputs.size(), 2UL);
}
if (n->Name() == "dummy") {
ASSERT_EQ(n->inputs[0]->Name(), "c");
ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
control_dep2 = n->inputs[1];
ASSERT_EQ(n->inputs.size(), 2UL);
}
}
ASSERT_EQ(control_dep1, control_dep2);
}
TEST(GraphTest, WriteAfterWrite) {
// void Test() {
ProgramDesc prog;
auto *op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum");
op->SetInput("X", {"a"});
op->SetOutput("Out", {"b"});
op->SetAttr("op_role", 1);
op = prog.MutableBlock(0)->AppendOp();
op->SetType("dummy");
op->SetInput("X", {"c"});
op->SetOutput("Out", {"b"});
op->SetAttr("op_role", 1);
prog.MutableBlock(0)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
ir::Node *control_dep1 = nullptr;
ir::Node *control_dep2 = nullptr;
for (ir::Node *n : g->Nodes()) {
if (n->Name() == "sum") {
ASSERT_EQ(n->outputs[0]->Name(), "b");
ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1]));
ASSERT_EQ(n->outputs.size(), 2UL);
control_dep1 = n->outputs[1];
}
if (n->Name() == "dummy") {
ASSERT_EQ(n->inputs[0]->Name(), "c");
ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
control_dep2 = n->inputs[1];
ASSERT_EQ(n->inputs.size(), 2UL);
}
}
ASSERT_NE(control_dep1, nullptr);
ASSERT_NE(control_dep2, nullptr);
ASSERT_EQ(control_dep1, control_dep2);
}
TEST(GraphTest, TestException) {
ProgramDesc prog;
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
......@@ -350,12 +270,13 @@ TEST(GraphTest, TestMultiBlock) {
op = prog.MutableBlock(1)->AppendOp();
op->SetType("dummy");
op->SetInput("X", {"c"});
op->SetOutput("Out", {"a"});
op->SetOutput("Out", {"d"});
op->SetAttr("op_role", 1);
prog.MutableBlock(1)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(1)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(1)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(1)->Var("d")->SetType(proto::VarType::LOD_TENSOR);
// Set contents in block_2.
op = prog.MutableBlock(2)->AppendOp();
......@@ -367,12 +288,13 @@ TEST(GraphTest, TestMultiBlock) {
op = prog.MutableBlock(2)->AppendOp();
op->SetType("dummy");
op->SetInput("X", {"c"});
op->SetOutput("Out", {"b"});
op->SetOutput("Out", {"d"});
op->SetAttr("op_role", 1);
prog.MutableBlock(2)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(2)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(2)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(1)->Var("d")->SetType(proto::VarType::LOD_TENSOR);
// Step2: Convert program into graph, 3 blocks corresponding 3 sub_graphs.
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
......@@ -399,45 +321,29 @@ TEST(GraphTest, TestMultiBlock) {
// Check contents in sub_graph_1.
const ir::Graph *g1 = g->GetSubGraph(1);
ir::Node *control_dep1 = nullptr;
ir::Node *control_dep2 = nullptr;
for (ir::Node *n : g1->Nodes()) {
if (n->Name() == "sum") {
ASSERT_EQ(n->outputs[0]->Name(), "b");
ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1]));
control_dep1 = n->outputs[1];
ASSERT_EQ(n->outputs.size(), 2UL);
ASSERT_EQ(n->outputs.size(), 1UL);
}
if (n->Name() == "dummy") {
ASSERT_EQ(n->inputs[0]->Name(), "c");
ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
control_dep2 = n->inputs[1];
ASSERT_EQ(n->inputs.size(), 2UL);
ASSERT_EQ(n->inputs.size(), 1UL);
}
}
ASSERT_EQ(control_dep1, control_dep2);
// Check contents in sub_graph_2.
const ir::Graph *g2 = g->GetSubGraph(2);
control_dep1 = nullptr;
control_dep2 = nullptr;
for (ir::Node *n : g2->Nodes()) {
if (n->Name() == "sum") {
ASSERT_EQ(n->outputs[0]->Name(), "b");
ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1]));
ASSERT_EQ(n->outputs.size(), 2UL);
control_dep1 = n->outputs[1];
ASSERT_EQ(n->outputs.size(), 1UL);
}
if (n->Name() == "dummy") {
ASSERT_EQ(n->inputs[0]->Name(), "c");
ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
control_dep2 = n->inputs[1];
ASSERT_EQ(n->inputs.size(), 2UL);
ASSERT_EQ(n->inputs.size(), 1UL);
}
}
ASSERT_NE(control_dep1, nullptr);
ASSERT_NE(control_dep2, nullptr);
ASSERT_EQ(control_dep1, control_dep2);
// Step3: Clone graph.
std::shared_ptr<ir::Graph> clone_g = g->Clone();
......
......@@ -331,8 +331,6 @@ void BatchMergePass::ApplyImpl(ir::Graph* graph) const {
copy_node(node);
}
}
result.ResolveHazard(created);
}
} // namespace ir
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册