From 044a82d8412098b969a83c12deb4e07426328102 Mon Sep 17 00:00:00 2001 From: levi131 <83750468+levi131@users.noreply.github.com> Date: Fri, 16 Jul 2021 15:23:37 +0800 Subject: [PATCH] Convert all blocks in program into SSAgraphs. (#33320) As the title, this PR converts all blocks in program into SSA sub graphs and it is guarded by flag --- paddle/fluid/framework/ir/graph.cc | 144 ++++++++++++++++--- paddle/fluid/framework/ir/graph.h | 156 ++++++++++++++++++++- paddle/fluid/framework/ir/graph_test.cc | 176 ++++++++++++++++++++++++ paddle/fluid/framework/ir/pass_test.cc | 109 +++++++++++++++ 4 files changed, 557 insertions(+), 28 deletions(-) diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index e8a3de1a88a..1f55f0aa3cb 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -17,6 +17,9 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/operator.h" +DEFINE_bool(convert_all_blocks, false, + "Convert all blocks in program into SSAgraphs"); + namespace paddle { namespace framework { namespace ir { @@ -24,16 +27,9 @@ namespace ir { Graph::Graph(const ProgramDesc &program) : Graph(program, 0, program.Block(0).AllOps().size()) {} -Graph::Graph(const ProgramDesc &program, int64_t start_op_index, - int64_t end_op_index) - : program_(program) { - auto var_nodes = InitFromProgram(program_, start_op_index, end_op_index); - ResolveHazard(var_nodes); -} - -std::map> Graph::InitFromProgram( - const ProgramDesc &program, int64_t start_op_index, int64_t end_op_index) { - VLOG(3) << "block in program:" << program_.Size(); +Graph::Graph(const ProgramDesc &program, const int64_t start_op_index, + const int64_t end_op_index) + : program_(program), main_graph_(nullptr) { PADDLE_ENFORCE_GE(start_op_index, 0, platform::errors::InvalidArgument( "Required start_op_index >= 0, but received " @@ -44,16 +40,65 @@ std::map> Graph::InitFromProgram( "Required end_op_index >= start_op_index, but received " "end_op_index: %d < start_op_index: %d", end_op_index, start_op_index)); + PADDLE_ENFORCE_GE( + program_.Size(), 1, + platform::errors::InvalidArgument("Can't construct a graph from this " + "program, it doesn't have a block")); + + const int64_t block_op_size = program_.Block(0).AllOps().size(); + PADDLE_ENFORCE_LE(end_op_index, block_op_size, + platform::errors::InvalidArgument( + "Required end_op_index <= block_op_size, but received " + "end_op_index: %d > block_op_size: %d", + end_op_index, block_op_size)); + if (FLAGS_convert_all_blocks) { + // NOTE(levi): start_op_index and end_op_index only work on the first + // sub_graph. + std::unique_ptr first_sub_graph = std::make_unique( + program_.Block(0), this, start_op_index, end_op_index); + sub_graphs_.push_back(std::move(first_sub_graph)); + for (size_t idx = 1; idx < program_.Size(); ++idx) { + std::unique_ptr sub_graph = + std::make_unique(program_.Block(idx), this); + sub_graphs_.push_back(std::move(sub_graph)); + } + } else { + auto var_nodes = InitFromProgram(program_, start_op_index, end_op_index); + ResolveHazard(var_nodes); + } +} + +Graph::Graph(const BlockDesc &block, const Graph *main_graph) + : Graph(block, main_graph, 0, block.AllOps().size()) {} + +Graph::Graph(const BlockDesc &block, const Graph *main_graph, + const int64_t start_op_index, const int64_t end_op_index) + : main_graph_(main_graph) { + auto var_nodes = InitFromBlock(block, start_op_index, end_op_index); + ResolveHazard(var_nodes); +} +// TODO(levi): delete this interface after when we can convert all +// blocks into sub_graphs. +std::map> Graph::InitFromProgram( + const ProgramDesc &program, const int64_t start_op_index, + const int64_t end_op_index) { + VLOG(3) << "block in program:" << program_.Size(); + return InitFromBlock(program.Block(0), start_op_index, end_op_index); +} + +std::map> Graph::InitFromBlock( + const BlockDesc &block, const int64_t start_op_index, + const int64_t end_op_index) { std::unordered_map all_vars; // var nodes for each var name, will have multiple versions in SSA std::map> var_nodes; - for (auto *var : program.Block(0).AllVars()) { + for (auto *var : block.AllVars()) { all_vars.emplace(var->Name(), var); } auto not_visited_vars = all_vars; - auto all_ops = program.Block(0).AllOps(); + auto all_ops = block.AllOps(); PADDLE_ENFORCE_LE( end_op_index, all_ops.size(), platform::errors::InvalidArgument( @@ -210,22 +255,77 @@ void Graph::ResolveHazard( } std::shared_ptr Graph::Clone() { - auto cloned_graph = std::make_shared(this->program_); - cloned_graph->ReleaseNodes(); - cloned_graph->num_node_created_ = 0; + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), true, + platform::errors::InvalidArgument( + "This graph is a sub_graph, and can't be cloned individually")); + if (FLAGS_convert_all_blocks) { + auto cloned_graph = std::make_shared(this->program_); + cloned_graph->ReleaseSubGraphs(); + for (size_t idx = 0; idx < this->program_.Size(); ++idx) { + cloned_graph->AddSubGraph(this->CloneSubGraph(idx)); + } + return cloned_graph; + } else { + auto cloned_graph = std::make_shared(this->program_); + cloned_graph->ReleaseNodes(); + cloned_graph->num_node_created_ = 0; + std::unordered_map origin_to_cloned; + for (auto *n : this->node_set_) { + PADDLE_ENFORCE_NOT_NULL(n, platform::errors::InvalidArgument( + "The node to be cloned is nullptr.")); + ir::Node *cloned_node = nullptr; + if (n->IsCtrlVar()) { + cloned_node = cloned_graph->CreateControlDepVar(); + } else if (!n->var_desc_ && !n->op_desc_) { // empty node + cloned_node = cloned_graph->CreateEmptyNode(n->Name(), n->NodeType()); + } else if (n->IsVar()) { + cloned_node = cloned_graph->CreateVarNode(n->Var()); + } else if (n->IsOp()) { + cloned_node = cloned_graph->CreateOpNode(n->Op()); + } + PADDLE_ENFORCE_NOT_NULL( + cloned_node, + platform::errors::InvalidArgument( + "Failed to clone new node from original node in graph.")); + origin_to_cloned[n] = cloned_node; + } + for (auto *n : this->node_set_) { + for (auto it = n->inputs.begin(); it != n->inputs.end(); it++) { + origin_to_cloned[n]->inputs.push_back(origin_to_cloned[*it]); + } + for (auto it = n->outputs.begin(); it != n->outputs.end(); it++) { + origin_to_cloned[n]->outputs.push_back(origin_to_cloned[*it]); + } + } + return cloned_graph; + } +} + +std::unique_ptr Graph::CloneSubGraph(const size_t idx) { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), true, + platform::errors::InvalidArgument("This graph is not main_graph")); + PADDLE_ENFORCE_LT( + idx, this->sub_graphs_.size(), + platform::errors::InvalidArgument("Invalid sub_graph index")); + std::unique_ptr cloned_sub_graph = + std::make_unique(this->program_.Block(idx), this); + cloned_sub_graph->ReleaseNodes(); + cloned_sub_graph->num_node_created_ = 0; std::unordered_map origin_to_cloned; - for (auto *n : this->node_set_) { + for (auto *n : this->sub_graphs_.at(idx)->Nodes()) { PADDLE_ENFORCE_NOT_NULL(n, platform::errors::InvalidArgument( "The node to be cloned is nullptr.")); ir::Node *cloned_node = nullptr; if (n->IsCtrlVar()) { - cloned_node = cloned_graph->CreateControlDepVar(); + cloned_node = cloned_sub_graph->CreateControlDepVar(); } else if (!n->var_desc_ && !n->op_desc_) { // empty node - cloned_node = cloned_graph->CreateEmptyNode(n->Name(), n->NodeType()); + cloned_node = cloned_sub_graph->CreateEmptyNode(n->Name(), n->NodeType()); } else if (n->IsVar()) { - cloned_node = cloned_graph->CreateVarNode(n->Var()); + cloned_node = cloned_sub_graph->CreateVarNode(n->Var()); } else if (n->IsOp()) { - cloned_node = cloned_graph->CreateOpNode(n->Op()); + cloned_node = cloned_sub_graph->CreateOpNode(n->Op()); } PADDLE_ENFORCE_NOT_NULL( cloned_node, @@ -233,7 +333,7 @@ std::shared_ptr Graph::Clone() { "Failed to clone new node from original node in graph.")); origin_to_cloned[n] = cloned_node; } - for (auto *n : this->node_set_) { + for (auto *n : this->sub_graphs_.at(idx)->Nodes()) { for (auto it = n->inputs.begin(); it != n->inputs.end(); it++) { origin_to_cloned[n]->inputs.push_back(origin_to_cloned[*it]); } @@ -241,7 +341,7 @@ std::shared_ptr Graph::Clone() { origin_to_cloned[n]->outputs.push_back(origin_to_cloned[*it]); } } - return cloned_graph; + return cloned_sub_graph; } bool IsControlDepVar(const ir::Node &var) { diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 26ca64ba821..da588426676 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include #include @@ -25,6 +26,8 @@ limitations under the License. */ #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/variant.h" +DECLARE_bool(convert_all_blocks); + namespace paddle { namespace framework { class OpDesc; @@ -78,10 +81,20 @@ namespace ir { */ class Graph { public: + // Construct a main_graph with some sub_graphs explicit Graph(const ProgramDesc &program); - // Construct a Graph with ops[start_op_index, end_op_index) - explicit Graph(const ProgramDesc &program, int64_t start_op_index, - int64_t end_op_index); + + // Construct a main_graph with some sub_graphs, and the 1st sub_graph is + // constructed with ops[start_op_index, end_op_index) + Graph(const ProgramDesc &program, const int64_t start_op_index, + const int64_t end_op_index); + + // Construct a sub_graph + Graph(const BlockDesc &block, const Graph *main_graph); + + // Construct a sub_graph with ops[start_op_index, end_op_index) + Graph(const BlockDesc &block, const Graph *main_graph, + const int64_t start_op_index, const int64_t end_op_index); virtual ~Graph() { for (auto &attr : attrs_) { @@ -94,11 +107,21 @@ class Graph { bool IsConstructedByPartialProgram() const { return is_partial_; } bool Has(const std::string &attr_name) const { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->Has(attr_name); + } + } return attrs_.count(attr_name) > 0; } template AttrType &GetOrInit(const std::string &attr_name) { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->GetOrInit(attr_name); + } + } if (!Has(attr_name)) { Set(attr_name, new AttrType); } @@ -107,6 +130,11 @@ class Graph { template AttrType &Get(const std::string &attr_name) const { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->Get(attr_name); + } + } PADDLE_ENFORCE_EQ( Has(attr_name), true, platform::errors::PreconditionNotMet( @@ -123,6 +151,11 @@ class Graph { template void Set(const std::string &attr_name, AttrType *attr) { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->Set(attr_name, attr); + } + } PADDLE_ENFORCE_EQ( attrs_.count(attr_name), 0, platform::errors::AlreadyExists( @@ -137,6 +170,11 @@ class Graph { template void SetNotOwned(const std::string &attr_name, AttrType *attr) { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->SetNotOwned(attr_name, attr); + } + } PADDLE_ENFORCE_EQ( attrs_.count(attr_name), 0, platform::errors::AlreadyExists("The attribute %s to be set(not owned) " @@ -147,6 +185,11 @@ class Graph { } void Erase(const std::string &attr_name) { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->Erase(attr_name); + } + } PADDLE_ENFORCE_NE( attrs_.count(attr_name), 0, platform::errors::NotFound( @@ -157,10 +200,22 @@ class Graph { attr_dels_.erase(attr_name); } - const std::unordered_set &Nodes() const { return node_set_; } + const std::unordered_set &Nodes() const { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->Nodes(); + } + } + return node_set_; + } // Create a normal variable with non-null VarDesc. ir::Node *CreateVarNode(VarDesc *var_desc) { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->CreateVarNode(var_desc); + } + } PADDLE_ENFORCE_NOT_NULL( var_desc, platform::errors::InvalidArgument( "The VarDesc used to create variable node is null.")); @@ -171,6 +226,11 @@ class Graph { // Create a normal runnable operator with OpDesc. ir::Node *CreateOpNode(OpDesc *op_desc) { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->CreateOpNode(op_desc); + } + } PADDLE_ENFORCE_NOT_NULL( op_desc, platform::errors::InvalidArgument( "The OpDesc used to create operator node is null.")); @@ -183,6 +243,11 @@ class Graph { // var doesn't hold any data. Other than that, it's no different from // other var, considering dependency analysis. ir::Node *CreateControlDepVar() { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->CreateControlDepVar(); + } + } // TODO(panyx0718): control var name should be really unique. const std::string name = string::Sprintf( "%s@%llu", static_cast(ir::Node::kControlDepVarName), @@ -195,6 +260,11 @@ class Graph { // A more free style way of creating a graph node. Mostly use for test // or "copy" from another node. Avoid using it if possible. ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->CreateEmptyNode(name, type); + } + } auto *x = AddNode(new ir::Node(name, type)); x->SetId(num_node_created_++); return x; @@ -203,6 +273,11 @@ class Graph { // Clear all node information of the graph and return the ownership of the // nodes. std::vector> ReleaseNodes() { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->ReleaseNodes(); + } + } std::vector> ret; for (auto &n : nodes_) { ret.emplace_back(n.second.release()); @@ -213,6 +288,11 @@ class Graph { } std::unique_ptr RemoveNode(ir::Node *node) { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->RemoveNode(node); + } + } PADDLE_ENFORCE_EQ(node_set_.find(node) != node_set_.end(), true, platform::errors::PreconditionNotMet( "The node to be removed does not exist.")); @@ -225,6 +305,11 @@ class Graph { // NOTE low performance, but simple and secure. Node *RetrieveNode(int id) { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->RetrieveNode(id); + } + } for (auto &node : nodes_) { if (node.second->id() == id) { return node.second.get(); @@ -237,10 +322,22 @@ class Graph { // WARN: After a series of passes, the current graph can be quite // different from OriginProgram. Caller shouldn't assume much from // the returned OriginProgram. - const ProgramDesc &OriginProgram() const { return program_; } + const ProgramDesc &OriginProgram() const { + if (FLAGS_convert_all_blocks) { + if (!IsMainGraph()) { + return main_graph_->OriginProgram(); + } + } + return program_; + } // This method takes ownership of `node`. ir::Node *AddNode(ir::Node *node) { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->AddNode(node); + } + } PADDLE_ENFORCE_EQ(node_set_.find(node) == node_set_.end(), true, platform::errors::PreconditionNotMet( "The node to be added already exists.")); @@ -256,12 +353,59 @@ class Graph { // WARN: The method only clones the graph structure, not its attributes. std::shared_ptr Clone(); + bool IsMainGraph() const { return main_graph_ == nullptr; } + + Graph *GetSubGraph(const size_t idx) const { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), true, + platform::errors::InvalidArgument("This graph is not main_graph")); + PADDLE_ENFORCE_LT( + idx, sub_graphs_.size(), + platform::errors::InvalidArgument("Invalid sub_graph index")); + return sub_graphs_.at(idx).get(); + } + + size_t SubGraphsSize() const { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), true, + platform::errors::InvalidArgument("This graph is not main_graph")); + return sub_graphs_.size(); + } + private: + // TODO(levi): delete this interface after when we can convert all + // blocks into sub_graphs. std::map> InitFromProgram( - const ProgramDesc &program, int64_t start_op_index, int64_t end_op_index); + const ProgramDesc &program, const int64_t start_op_index, + const int64_t end_op_index); + + std::map> InitFromBlock( + const BlockDesc &block, const int64_t start_op_index, + const int64_t end_op_index); + + void ReleaseSubGraphs() { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), true, + platform::errors::InvalidArgument("This graph is not main_graph")); + sub_graphs_.clear(); + } + + void AddSubGraph(std::unique_ptr sub_graph) { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), true, + platform::errors::InvalidArgument("This graph is not main_graph")); + sub_graphs_.push_back(std::move(sub_graph)); + } + + std::unique_ptr CloneSubGraph(const size_t idx); // NOTE: program_ shouldn't be exposed to user. const ProgramDesc program_; + // NOTE: main_graph_ doesn't hold any node. It's used as a container of + // sub_graphs, and the sub_graph holds the nodes. + const Graph *main_graph_; // not owned. + std::vector> sub_graphs_; + std::map attrs_; std::map> attr_dels_; std::map> nodes_; diff --git a/paddle/fluid/framework/ir/graph_test.cc b/paddle/fluid/framework/ir/graph_test.cc index 66507fe7caf..1ff67ae0fe0 100644 --- a/paddle/fluid/framework/ir/graph_test.cc +++ b/paddle/fluid/framework/ir/graph_test.cc @@ -264,5 +264,181 @@ TEST(GraphTest, TestAttrCopy) { ASSERT_FALSE(dst_g.Has(kFloatValue)); } +TEST(GraphTest, TestInterfaceConvertAllBlocks) { + // Set FLAGS_convert_all_blocks to true to make sure this test works. + bool flag_temp = FLAGS_convert_all_blocks; + FLAGS_convert_all_blocks = true; + + ProgramDesc prog; + prog.MutableBlock(0)->Var("init_var")->SetType(proto::VarType::SELECTED_ROWS); + ir::Graph g(prog); + ASSERT_TRUE(g.IsMainGraph()); + + const std::string kIntValue = "int_value"; + const int INT_VALUE = 3; + g.Set(kIntValue, new int(INT_VALUE)); + ASSERT_TRUE(g.Has(kIntValue)); + ASSERT_EQ(g.GetOrInit(kIntValue), INT_VALUE); + ASSERT_EQ(g.Get(kIntValue), INT_VALUE); + g.Erase(kIntValue); + ASSERT_TRUE(!g.Has(kIntValue)); + g.SetNotOwned(kIntValue, new int(INT_VALUE)); + ASSERT_TRUE(g.Has(kIntValue)); + g.Erase(kIntValue); + + g.ReleaseNodes(); + ASSERT_EQ(g.Nodes().size(), 0UL); + g.CreateVarNode(new VarDesc("temp_var_desc_name")); + g.CreateOpNode(prog.MutableBlock(0)->AppendOp()); + g.CreateControlDepVar(); + g.CreateEmptyNode("temp_empty_node_name", ir::Node::Type::kVariable); + ASSERT_EQ(g.Nodes().size(), 4UL); + g.RemoveNode(g.RetrieveNode(1)); + ASSERT_EQ(g.Nodes().size(), 3UL); + + // Recover FLAGS_convert_all_blocks. + FLAGS_convert_all_blocks = flag_temp; +} + +TEST(GraphTest, TestMultiBlock) { + // Set FLAGS_convert_all_blocks to true to make sure this test works. + bool flag_temp = FLAGS_convert_all_blocks; + FLAGS_convert_all_blocks = true; + + // Step1: Build a program with 3 blocks. + ProgramDesc prog; + ASSERT_EQ(prog.Size(), 1UL); + prog.AppendBlock(prog.Block(0)); + prog.AppendBlock(prog.Block(0)); + ASSERT_EQ(prog.Size(), 3UL); + + // Set contents in block_0. + auto *op = prog.MutableBlock(0)->AppendOp(); + op->SetType("sum"); + op->SetInput("X", {"test_a", "test_b", "test_c"}); + op->SetOutput("Out", {"test_out"}); + op->SetAttr("op_role", 1); + + prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarType::SELECTED_ROWS); + prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::SELECTED_ROWS); + prog.MutableBlock(0)->Var("test_c")->SetType(proto::VarType::SELECTED_ROWS); + prog.MutableBlock(0)->Var("test_out"); + op->InferVarType(prog.MutableBlock(0)); + ASSERT_EQ(proto::VarType::SELECTED_ROWS, + prog.MutableBlock(0)->Var("test_out")->GetType()); + + prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::LOD_TENSOR); + op->InferVarType(prog.MutableBlock(0)); + ASSERT_EQ(proto::VarType::LOD_TENSOR, + prog.MutableBlock(0)->Var("test_out")->GetType()); + + // Set contents in block_1. + op = prog.MutableBlock(1)->AppendOp(); + op->SetType("sum"); + op->SetInput("X", {"a"}); + op->SetOutput("Out", {"b"}); + op->SetAttr("op_role", 1); + + op = prog.MutableBlock(1)->AppendOp(); + op->SetType("dummy"); + op->SetInput("X", {"c"}); + op->SetOutput("Out", {"a"}); + op->SetAttr("op_role", 1); + + prog.MutableBlock(1)->Var("a")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(1)->Var("b")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(1)->Var("c")->SetType(proto::VarType::LOD_TENSOR); + + // Set contents in block_2. + op = prog.MutableBlock(2)->AppendOp(); + op->SetType("sum"); + op->SetInput("X", {"a"}); + op->SetOutput("Out", {"b"}); + op->SetAttr("op_role", 1); + + op = prog.MutableBlock(2)->AppendOp(); + op->SetType("dummy"); + op->SetInput("X", {"c"}); + op->SetOutput("Out", {"b"}); + op->SetAttr("op_role", 1); + + prog.MutableBlock(2)->Var("a")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(2)->Var("b")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(2)->Var("c")->SetType(proto::VarType::LOD_TENSOR); + + // Step2: Convert program into graph, 3 blocks corresponding 3 sub_graphs. + std::unique_ptr g(new ir::Graph(prog)); + ASSERT_EQ(g->IsMainGraph(), true); + ASSERT_EQ(g->SubGraphsSize(), 3UL); + + // Check contents in sub_graph_0. + const ir::Graph *g0 = g->GetSubGraph(0); + std::vector nodes(g0->Nodes().begin(), g0->Nodes().end()); + for (ir::Node *n : nodes) { + if (n->Name() == "sum") { + ASSERT_EQ(n->inputs.size(), 3UL); + ASSERT_EQ(n->outputs.size(), 1UL); + } else if (n->Name() == "test_a" || n->Name() == "test_b" || + n->Name() == "test_c") { + ASSERT_EQ(n->inputs.size(), 0UL); + ASSERT_EQ(n->outputs.size(), 1UL); + } else if (n->Name() == "test_out") { + ASSERT_EQ(n->inputs.size(), 1UL); + ASSERT_EQ(n->outputs.size(), 0UL); + } + } + ASSERT_EQ(nodes.size(), 5UL); + + // Check contents in sub_graph_1. + const ir::Graph *g1 = g->GetSubGraph(1); + ir::Node *control_dep1 = nullptr; + ir::Node *control_dep2 = nullptr; + for (ir::Node *n : g1->Nodes()) { + if (n->Name() == "sum") { + ASSERT_EQ(n->outputs[0]->Name(), "b"); + ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1])); + control_dep1 = n->outputs[1]; + ASSERT_EQ(n->outputs.size(), 2UL); + } + if (n->Name() == "dummy") { + ASSERT_EQ(n->inputs[0]->Name(), "c"); + ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1])); + control_dep2 = n->inputs[1]; + ASSERT_EQ(n->inputs.size(), 2UL); + } + } + ASSERT_EQ(control_dep1, control_dep2); + + // Check contents in sub_graph_2. + const ir::Graph *g2 = g->GetSubGraph(2); + control_dep1 = nullptr; + control_dep2 = nullptr; + for (ir::Node *n : g2->Nodes()) { + if (n->Name() == "sum") { + ASSERT_EQ(n->outputs[0]->Name(), "b"); + ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1])); + ASSERT_EQ(n->outputs.size(), 2UL); + control_dep1 = n->outputs[1]; + } + if (n->Name() == "dummy") { + ASSERT_EQ(n->inputs[0]->Name(), "c"); + ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1])); + control_dep2 = n->inputs[1]; + ASSERT_EQ(n->inputs.size(), 2UL); + } + } + ASSERT_NE(control_dep1, nullptr); + ASSERT_NE(control_dep2, nullptr); + ASSERT_EQ(control_dep1, control_dep2); + + // Step3: Clone graph. + std::shared_ptr clone_g = g->Clone(); + ASSERT_EQ(clone_g->IsMainGraph(), true); + ASSERT_EQ(clone_g->SubGraphsSize(), 3UL); + + // Recover FLAGS_convert_all_blocks. + FLAGS_convert_all_blocks = flag_temp; +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/pass_test.cc b/paddle/fluid/framework/ir/pass_test.cc index 65b9c427869..616ba7f1a97 100644 --- a/paddle/fluid/framework/ir/pass_test.cc +++ b/paddle/fluid/framework/ir/pass_test.cc @@ -135,6 +135,93 @@ TEST(PassTest, TestPassAttrCheck) { exception.npos); } +TEST(PassTest, TestPassAttrCheckConvertAllBlocks) { + // Set FLAGS_convert_all_blocks to true to make sure this test works. + bool flag_temp = FLAGS_convert_all_blocks; + FLAGS_convert_all_blocks = true; + + ProgramDesc prog; + auto pass = PassRegistry::Instance().Get("test_pass"); + std::unique_ptr graph(new Graph(prog)); + std::string exception; + try { + graph.reset(pass->Apply(graph.release())); + } catch (paddle::platform::EnforceNotMet& e) { + exception = std::string(e.what()); + } + ASSERT_TRUE(exception.find("Required atrribute test_pass_attr for pass < " + "test_pass > is not set") != exception.npos); + + int val = 1; + graph.reset(new Graph(prog)); + pass->SetNotOwned("test_pass_attr", &val); + + for (std::string try_type : {"bool", "const int", "std::string"}) { + try { + if (try_type == "bool") { + pass->Get("test_pass_attr"); + } else if (try_type == "const int") { + pass->Get("test_pass_attr"); + } else if (try_type == "std::string") { + pass->Get("test_pass_attr"); + } + } catch (paddle::platform::EnforceNotMet& e) { + exception = std::string(e.what()); + } + std::string msg = "Invalid type for attritube test_pass_attr, expected: " + + try_type + ", actual: int"; + ASSERT_TRUE(exception.find(msg) != exception.npos); + } + + try { + graph.reset(pass->Apply(graph.release())); + } catch (paddle::platform::EnforceNotMet& e) { + exception = std::string(e.what()); + } + ASSERT_TRUE(exception.find( + "Required atrribute test_graph_attr for graph is not set") != + exception.npos); + + graph.reset(new Graph(prog)); + graph->Set("test_graph_attr", new int); + graph->Get("test_graph_attr") = 1; + graph.reset(pass->Apply(graph.release())); + ASSERT_EQ(graph->Get("copy_test_pass_attr"), 2); + ASSERT_EQ(graph->Get("copy_test_graph_attr"), 2); + + // Allow apply more than once. + graph.reset(new Graph(prog)); + graph->Set("test_graph_attr", new int); + graph.reset(pass->Apply(graph.release())); + + pass = PassRegistry::Instance().Get("test_pass"); + pass->SetNotOwned("test_pass_attr", &val); + graph.reset(new Graph(prog)); + BuildCircleGraph(graph.get()); + graph->Set("test_graph_attr", new int); + graph->Get("test_graph_attr") = 2; + try { + pass->Apply(graph.release()); + } catch (paddle::platform::EnforceNotMet& e) { + exception = std::string(e.what()); + } + ASSERT_TRUE(exception.find("shouldn't contain cycle") != exception.npos); + + pass = PassRegistry::Instance().Get("test_pass"); + pass->Set("test_pass_attr", new int); + try { + pass->Set("test_pass_attr", new int); + } catch (paddle::platform::EnforceNotMet& e) { + exception = std::string(e.what()); + } + ASSERT_TRUE( + exception.find("Attribute test_pass_attr already set in the pass") != + exception.npos); + + // Recover FLAGS_convert_all_blocks. + FLAGS_convert_all_blocks = flag_temp; +} + class TestPassWithDefault : public Pass { protected: void ApplyImpl(ir::Graph* graph) const { @@ -160,6 +247,28 @@ TEST(PassTest, TestPassDefaultAttrCheck) { ASSERT_EQ(pass->Get("default_attr"), 3); } +TEST(PassTest, TestPassDefaultAttrCheckConvertAllBlocks) { + // Set FLAGS_convert_all_blocks to true to make sure this test works. + bool flag_temp = FLAGS_convert_all_blocks; + FLAGS_convert_all_blocks = true; + + ProgramDesc prog; + // check if default value is set + auto pass = PassRegistry::Instance().Get("test_pass_default_attr"); + std::unique_ptr graph(new Graph(prog)); + ASSERT_EQ(pass->Get("default_attr"), 1); + graph.reset(pass->Apply(graph.release())); + ASSERT_EQ(graph->Get("copy_default_attr"), 2); + + // check if new value overrides default value + pass = PassRegistry::Instance().Get("test_pass_default_attr"); + pass->Set("default_attr", new int{3}); + ASSERT_EQ(pass->Get("default_attr"), 3); + + // Recover FLAGS_convert_all_blocks. + FLAGS_convert_all_blocks = flag_temp; +} + TEST(PassTest, TestPassRegistrarDeconstructor) { auto pass_registrary = new PassRegistrar( -- GitLab