From 167523e75b8627976612fa8e166bd0e425414a0a Mon Sep 17 00:00:00 2001 From: jiangcheng Date: Wed, 28 Jul 2021 15:06:38 +0800 Subject: [PATCH] graph_to_program topology sort (#33949) See https://github.com/PaddlePaddle/Paddle/pull/33949 for details --- paddle/fluid/framework/ir/graph.cc | 40 +- paddle/fluid/framework/ir/graph.h | 32 +- paddle/fluid/framework/ir/graph_helper.cc | 80 ++++ paddle/fluid/framework/ir/graph_helper.h | 2 + .../framework/ir/graph_to_program_pass.cc | 102 ++++- .../framework/ir/graph_to_program_pass.h | 3 + .../ir/graph_to_program_pass_test.cc | 382 ++++++++++++++++++ .../while_op_eager_deletion_pass.cc | 11 + paddle/fluid/framework/ir/node.cc | 2 +- paddle/fluid/framework/ir/node.h | 114 +++++- paddle/fluid/framework/ir/node_test.cc | 27 ++ 11 files changed, 762 insertions(+), 33 deletions(-) diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index 1f55f0aa3cb..3914e08d995 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -56,10 +56,12 @@ Graph::Graph(const ProgramDesc &program, const int64_t start_op_index, // sub_graph. std::unique_ptr first_sub_graph = std::make_unique( program_.Block(0), this, start_op_index, end_op_index); + first_sub_graph->block_id_ = 0; sub_graphs_.push_back(std::move(first_sub_graph)); for (size_t idx = 1; idx < program_.Size(); ++idx) { std::unique_ptr sub_graph = std::make_unique(program_.Block(idx), this); + sub_graph->block_id_ = idx; sub_graphs_.push_back(std::move(sub_graph)); } } else { @@ -90,14 +92,32 @@ std::map> Graph::InitFromProgram( std::map> Graph::InitFromBlock( const BlockDesc &block, const int64_t start_op_index, const int64_t end_op_index) { - std::unordered_map all_vars; + std::unordered_map> + name_to_desc_block_id; + + const BlockDesc *block_var_visible = █ + while (block_var_visible != nullptr) { + for (auto *var : block_var_visible->AllVars()) { + name_to_desc_block_id.emplace( + var->Name(), std::make_pair(var, block_var_visible->ID())); + } + const BlockDesc *forward_block = block_var_visible->ForwardBlock(); + if (forward_block != nullptr) { + for (auto *var : forward_block->AllVars()) { + name_to_desc_block_id.emplace(var->Name(), + std::make_pair(var, forward_block->ID())); + } + } + block_var_visible = block_var_visible->ParentBlock(); + } // var nodes for each var name, will have multiple versions in SSA std::map> var_nodes; + std::unordered_map not_visited_vars; for (auto *var : block.AllVars()) { - all_vars.emplace(var->Name(), var); + not_visited_vars.emplace(var->Name(), var); } - auto not_visited_vars = all_vars; + int desc_order = 0; auto all_ops = block.AllOps(); PADDLE_ENFORCE_LE( end_op_index, all_ops.size(), @@ -109,6 +129,8 @@ std::map> Graph::InitFromBlock( auto *op = all_ops[i]; VLOG(3) << "create OpNode by " << op->Type(); ir::Node *node = CreateOpNode(op); + node->SetDescOrder(desc_order); + ++desc_order; // For input args, reuse the same var name if it was created before. // Otherwise, create a new one. 
for (auto &each_var_name : op->InputArgumentNames()) { @@ -116,8 +138,9 @@ std::map> Graph::InitFromBlock( ir::Node *var = nullptr; if (var_nodes.find(each_var_name) != var_nodes.end()) { var = var_nodes.at(each_var_name).back(); - } else if (all_vars.count(each_var_name) != 0) { - var = CreateVarNode(all_vars.at(each_var_name)); + } else if (name_to_desc_block_id.count(each_var_name) != 0) { + auto desc_and_block_id = name_to_desc_block_id.at(each_var_name); + var = CreateVarNode(desc_and_block_id.first, desc_and_block_id.second); var_nodes[each_var_name].push_back(var); } else { // Operation input var can be optional (dispensable). Which means @@ -143,8 +166,9 @@ std::map> Graph::InitFromBlock( } ir::Node *var = nullptr; - if (all_vars.count(each_var_name) != 0) { - var = CreateVarNode(all_vars.at(each_var_name)); + if (name_to_desc_block_id.count(each_var_name) != 0) { + auto desc_and_block_id = name_to_desc_block_id.at(each_var_name); + var = CreateVarNode(desc_and_block_id.first, desc_and_block_id.second); } else { // Operation output vars can be @EMPTY@. For example, while_grad // can have multi @EMPTY@ outputs with no VarDesc. @@ -270,6 +294,7 @@ std::shared_ptr Graph::Clone() { auto cloned_graph = std::make_shared(this->program_); cloned_graph->ReleaseNodes(); cloned_graph->num_node_created_ = 0; + cloned_graph->block_id_ = this->block_id_; std::unordered_map origin_to_cloned; for (auto *n : this->node_set_) { PADDLE_ENFORCE_NOT_NULL(n, platform::errors::InvalidArgument( @@ -313,6 +338,7 @@ std::unique_ptr Graph::CloneSubGraph(const size_t idx) { std::make_unique(this->program_.Block(idx), this); cloned_sub_graph->ReleaseNodes(); cloned_sub_graph->num_node_created_ = 0; + cloned_sub_graph->block_id_ = idx; std::unordered_map origin_to_cloned; for (auto *n : this->sub_graphs_.at(idx)->Nodes()) { PADDLE_ENFORCE_NOT_NULL(n, platform::errors::InvalidArgument( diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index da588426676..50c5671cb91 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -104,7 +104,14 @@ class Graph { attr_dels_.clear(); } - bool IsConstructedByPartialProgram() const { return is_partial_; } + bool IsConstructedByPartialProgram() const { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->IsConstructedByPartialProgram(); + } + } + return is_partial_; + } bool Has(const std::string &attr_name) const { if (FLAGS_convert_all_blocks) { @@ -210,7 +217,7 @@ class Graph { } // Create a normal variable with non-null VarDesc. - ir::Node *CreateVarNode(VarDesc *var_desc) { + ir::Node *CreateVarNode(VarDesc *var_desc, int block_id = -1) { if (FLAGS_convert_all_blocks) { if (IsMainGraph()) { return GetSubGraph(0)->CreateVarNode(var_desc); @@ -219,7 +226,8 @@ class Graph { PADDLE_ENFORCE_NOT_NULL( var_desc, platform::errors::InvalidArgument( "The VarDesc used to create variable node is null.")); - auto *x = AddNode(new ir::Node(var_desc)); + auto *x = + AddNode(new ir::Node(var_desc, block_id == -1 ? 
block_id_ : block_id)); x->SetId(num_node_created_++); return x; } @@ -252,7 +260,7 @@ class Graph { const std::string name = string::Sprintf( "%s@%llu", static_cast(ir::Node::kControlDepVarName), num_node_created_); - auto *x = AddNode(new ir::Node(name, ir::Node::Type::kVariable)); + auto *x = AddNode(new ir::Node(name, ir::Node::Type::kVariable, block_id_)); x->SetId(num_node_created_++); return x; } @@ -265,7 +273,7 @@ class Graph { return GetSubGraph(0)->CreateEmptyNode(name, type); } } - auto *x = AddNode(new ir::Node(name, type)); + auto *x = AddNode(new ir::Node(name, type, block_id_)); x->SetId(num_node_created_++); return x; } @@ -365,6 +373,15 @@ class Graph { return sub_graphs_.at(idx).get(); } + int GetBlockId() const { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->block_id_; + } + } + return block_id_; + } + size_t SubGraphsSize() const { PADDLE_ENFORCE_EQ( this->IsMainGraph(), true, @@ -394,6 +411,9 @@ class Graph { PADDLE_ENFORCE_EQ( this->IsMainGraph(), true, platform::errors::InvalidArgument("This graph is not main_graph")); + PADDLE_ENFORCE_EQ(sub_graphs_.size(), sub_graph->block_id_, + platform::errors::InvalidArgument( + "sub_graph idx is not equal to block_id_")); sub_graphs_.push_back(std::move(sub_graph)); } @@ -416,6 +436,8 @@ class Graph { // parts: forward graph and backward graph, which can be executed // independently. bool is_partial_{false}; + // The block this SubGraph belongs to. + int block_id_{0}; }; bool IsControlDepVar(const ir::Node &var); diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index cfdda435e65..50174cfbbba 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/ir/graph_helper.h" +#include #include DEFINE_string(print_sub_graph_dir, "", @@ -395,6 +396,85 @@ std::vector TopologyVarientSort(const Graph &graph, } } +class DescOrderComparator { + public: + bool operator()(const Node *n1, const Node *n2) { + return (n1->DescOrder() > n2->DescOrder()) || + ((n1->DescOrder() == n2->DescOrder()) && + (n1->ToString() > n2->ToString())); + } +}; + +std::vector TopologySortGraphByDescOrder(const Graph &graph) { + std::vector sorted_ops; + std::priority_queue, DescOrderComparator> q; + std::unordered_map> in_ops; + std::unordered_map> out_ops; + + // ensure all op node in 'in_ops' and 'out_ops' + for (const auto &n : graph.Nodes()) { + if (!n->IsOp()) continue; + + in_ops.emplace(n, std::unordered_set()); + out_ops.emplace(n, std::unordered_set()); + } + + // record all op's input op and output op + for (const auto &n : graph.Nodes()) { + if (!n->IsOp()) continue; + + // traverse all input op + for (const auto &var : n->inputs) { + for (const auto &in : var->inputs) { + // use at instead of [] to prevent no unrecorded op node + in_ops.at(n).insert(in); + out_ops.at(in).insert(n); + } + } + } + + // find topology entrance + for (const auto &n : graph.Nodes()) { + if (!n->IsOp()) continue; + + if (in_ops.at(n).empty()) { + q.push(n); + } + } + + // topological sorting + while (!q.empty()) { + // Do not get by reference!!! The element will pop later. 
+ const auto cur_op = q.top(); + q.pop(); + + sorted_ops.push_back(cur_op); + for (const auto &out : out_ops.at(cur_op)) { + PADDLE_ENFORCE_GT(in_ops.at(out).count(cur_op), 0, + platform::errors::InvalidArgument( + "We find %s in %s's output list, " + "but cannot find %s in %s's input list. " + "Please ensure graph completely.", + out->Name().c_str(), cur_op->Name().c_str(), + cur_op->Name().c_str(), out->Name().c_str())); + in_ops.at(out).erase(cur_op); + + // push if in-degree is 0 + if (in_ops.at(out).empty()) { + q.push(out); + } + } + } + + PADDLE_ENFORCE_EQ( + sorted_ops.size(), in_ops.size(), + platform::errors::InvalidArgument("Topological sorting incompletely, " + "only sorted %zd op but total %zd.", + sorted_ops.size(), in_ops.size())); + + return sorted_ops; +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_helper.h b/paddle/fluid/framework/ir/graph_helper.h index 0c43febca70..27a4fe25cd5 100644 --- a/paddle/fluid/framework/ir/graph_helper.h +++ b/paddle/fluid/framework/ir/graph_helper.h @@ -87,6 +87,8 @@ std::vector FilterByNodeWrapper(const Graph &graph) { return ret; } +std::vector TopologySortGraphByDescOrder(const Graph &graph); + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_to_program_pass.cc b/paddle/fluid/framework/ir/graph_to_program_pass.cc index 944db2b772e..b31ccd48aa9 100644 --- a/paddle/fluid/framework/ir/graph_to_program_pass.cc +++ b/paddle/fluid/framework/ir/graph_to_program_pass.cc @@ -14,7 +14,13 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/graph_to_program_pass.h" +#include +#include + #include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/op_proto_maker.h" + +DECLARE_bool(convert_all_blocks); namespace paddle { namespace framework { @@ -27,13 +33,10 @@ namespace framework { namespace ir { void GraphToProgramPass::ApplyImpl(ir::Graph* graph) const { - // Remove the unneeded variables after memory optimization. 
- std::unordered_set vars2remove; - if (graph->Has(kGraphToProgramVarsToRemove)) { - vars2remove = graph->Get>( - kGraphToProgramVarsToRemove); - VLOG(2) << "graph to program remove " << vars2remove.size() << " nodes"; - } + PADDLE_ENFORCE_EQ(graph->IsMainGraph(), true, + platform::errors::InvalidArgument( + "This graph is a sub_graph, " + "and can't convert to program individually")); ProgramDesc& program = Get("program"); @@ -42,12 +45,79 @@ void GraphToProgramPass::ApplyImpl(ir::Graph* graph) const { auto block = program_pb->mutable_blocks(kRootBlockIndex); block->set_idx(kRootBlockIndex); + + if (FLAGS_convert_all_blocks) { + GraphToBlock(graph->GetSubGraph(kRootBlockIndex), block); + + VLOG(3) << "Graph to program need convert " << graph->SubGraphsSize() + << " sub graph"; + for (size_t idx = 0; idx < graph->SubGraphsSize(); ++idx) { + // avoid kRootBlockIndex not 0 + if (idx == kRootBlockIndex) continue; + + block = program_pb->add_blocks(); + block->set_idx(idx); + GraphToBlock(graph->GetSubGraph(idx), block); + } + } else { + GraphToBlock(graph, block); + } + + program.CopyFrom(*program_pb); +} + +OpDesc* ReplaceScaleLossGradOp(ir::Node* node, OpDesc* desc) { + desc->SetType("fill_constant"); + desc->SetAttr( + OpProtoAndCheckerMaker::OpRoleAttrName(), + (static_cast(OpRole::kBackward) | static_cast(OpRole::kLoss))); + desc->SetAttr("value", 1.0f); + std::vector output_names; + for (auto out : node->outputs) { + output_names.emplace_back(out->Name()); + } + desc->SetOutput("Out", output_names); + return desc; +} + +std::vector* GetGraphOpDesc(const std::vector& nodes, + std::vector* ops) { + for (ir::Node* n : nodes) { + // if node is not Op, skip + if (!n->IsOp()) continue; + + // create fill_constant op + if (n->Name() == "scale_loss_grad") { + ops->emplace_back(); + auto& desc = ops->back(); + ReplaceScaleLossGradOp(n, &desc); + } else if (n->Op()) { + ops->emplace_back(*n->Op()); + } else { + // delete no OpDesc op + } + } + return ops; +} + +void GraphToProgramPass::GraphToBlock(const Graph* graph, + proto::BlockDesc* block) const { + // Remove the unneeded variables after memory optimization. 
+ std::unordered_set vars2remove; + if (graph->Has(kGraphToProgramVarsToRemove)) { + vars2remove = graph->Get>( + kGraphToProgramVarsToRemove); + VLOG(2) << "graph (id: " << block->idx() << ") to program remove " + << vars2remove.size() << " nodes"; + } + block->clear_vars(); std::unordered_set visited_vars; for (ir::Node* n : graph->Nodes()) { if (n->IsVar()) { if (n->Var() && visited_vars.count(n->Var()->Name()) == 0 && - !vars2remove.count(n->Var()->Name())) { + !vars2remove.count(n->Var()->Name()) && + n->GetVarNodeBlockId() == graph->GetBlockId()) { visited_vars.insert(n->Var()->Name()); block->add_vars()->MergeFrom(*n->Var()->Proto()); } @@ -62,16 +132,18 @@ void GraphToProgramPass::ApplyImpl(ir::Graph* graph) const { nodes = TopologyVarientSort( *graph, static_cast(sort_kind)); } else { - nodes = TopologySortOperations(*graph); + if (FLAGS_convert_all_blocks) { + nodes = TopologySortGraphByDescOrder(*graph); + } else { + nodes = TopologySortOperations(*graph); + } } - for (ir::Node* n : nodes) { - if (!n->Op()) continue; - - block->add_ops()->MergeFrom(*n->Op()->Proto()); + std::vector ops; + GetGraphOpDesc(nodes, &ops); + for (auto& op : ops) { + block->add_ops()->MergeFrom(*op.Proto()); } - - program.CopyFrom(*program_pb); } } // namespace ir diff --git a/paddle/fluid/framework/ir/graph_to_program_pass.h b/paddle/fluid/framework/ir/graph_to_program_pass.h index 6b17c0076f6..4997c67a92f 100644 --- a/paddle/fluid/framework/ir/graph_to_program_pass.h +++ b/paddle/fluid/framework/ir/graph_to_program_pass.h @@ -29,6 +29,9 @@ const char kGraphToProgramSortKind[] = "__graph_to_program_sort_kind__"; class GraphToProgramPass : public Pass { protected: void ApplyImpl(ir::Graph* graph) const override; + + private: + void GraphToBlock(const Graph* graph, proto::BlockDesc* block) const; }; } // namespace ir diff --git a/paddle/fluid/framework/ir/graph_to_program_pass_test.cc b/paddle/fluid/framework/ir/graph_to_program_pass_test.cc index 12119ff56dc..4530faf53dc 100644 --- a/paddle/fluid/framework/ir/graph_to_program_pass_test.cc +++ b/paddle/fluid/framework/ir/graph_to_program_pass_test.cc @@ -14,8 +14,14 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/ir/graph_to_program_pass.h" +#include + #include "gtest/gtest.h" +#include "paddle/fluid/framework/details/build_strategy.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/var_desc.h" namespace paddle { namespace framework { @@ -103,6 +109,382 @@ TEST(GraphToProgramPass, Basic) { EXPECT_TRUE(vars.find("var2") != vars.end()); EXPECT_TRUE(vars.find("var3") != vars.end()); } + +void BuildProgramWithMultiBlock(ProgramDesc* program) { + auto* global_block = program->MutableBlock(0); + auto* mul_1_x = global_block->Var("Mul_1_X"); + mul_1_x->SetType(proto::VarType::LOD_TENSOR); + mul_1_x->SetLoDLevel(0); + mul_1_x->SetDataType(proto::VarType::FP32); + mul_1_x->SetShape({1000, 784}); + + auto* mul_1_y = global_block->Var("Mul_1_Y"); + mul_1_y->SetType(proto::VarType::LOD_TENSOR); + mul_1_y->SetLoDLevel(0); + mul_1_y->SetDataType(proto::VarType::FP32); + mul_1_y->SetShape({784, 100}); + + auto* mul_1_out = global_block->Var("Mul_1_Out"); + mul_1_out->SetType(proto::VarType::LOD_TENSOR); + auto* mul_op_1 = global_block->AppendOp(); + + mul_op_1->SetType("mul"); + mul_op_1->SetInput("X", {mul_1_x->Name()}); + mul_op_1->SetInput("Y", {mul_1_y->Name()}); + mul_op_1->SetOutput("Y", {mul_1_out->Name()}); + + // building cond op such as less_than + auto* less_than_op_1 = global_block->AppendOp(); + less_than_op_1->SetType("less_than"); + auto* less_than_1_x = global_block->Var("Less_than_1_X"); + less_than_1_x->SetType(proto::VarType::LOD_TENSOR); + less_than_1_x->SetLoDLevel(0); + less_than_1_x->SetDataType(proto::VarType::FP32); + less_than_1_x->SetShape({1}); + + auto* less_than_1_y = global_block->Var("Less_than_1_Y"); + less_than_1_y->SetType(proto::VarType::LOD_TENSOR); + less_than_1_y->SetLoDLevel(0); + less_than_1_y->SetDataType(proto::VarType::FP32); + less_than_1_y->SetShape({1}); + + auto* less_than_1_out = global_block->Var("Less_than_1_Out"); + less_than_1_out->SetType(proto::VarType::BOOL); + + less_than_op_1->SetInput("X", {less_than_1_x->Name()}); + less_than_op_1->SetInput("Y", {less_than_1_y->Name()}); + less_than_op_1->SetOutput("Out", {less_than_1_out->Name()}); + + BlockDesc* sub_block = program->AppendBlock(*global_block); + std::vector sub_blocks; + sub_blocks.push_back(sub_block); + + BlockDesc* sub_block2 = + program->AppendBlock(*sub_block); // for testing nested case. 
+ sub_blocks.push_back(sub_block2); + + // building while op in sub_block + auto* while_op = global_block->AppendOp(); + while_op->SetType("while"); + while_op->SetAttr("sub_block", sub_blocks[0]); + + auto* while_x = global_block->Var("While_X"); + while_x->SetType(proto::VarType::LOD_TENSOR); + while_x->SetLoDLevel(0); + while_x->SetDataType(proto::VarType::FP32); + while_x->SetShape({1}); + + while_op->SetInput("kX", {while_x->Name()}); + while_op->SetInput("kCondition", {less_than_1_out->Name()}); + + auto* while_out = global_block->Var("While_Out"); + while_out->SetType(proto::VarType::LOD_TENSOR); + while_out->SetLoDLevel(0); + while_out->SetDataType(proto::VarType::FP32); + while_out->SetShape({1}); + + auto* steps = global_block->Var("StepScopes"); + + while_op->SetOutput("kOutputs", {while_out->Name()}); + while_op->SetOutput("kStepScopes", {steps->Name()}); + + auto* mul_2_x = global_block->Var("Mul_2_X"); + mul_2_x->SetType(proto::VarType::LOD_TENSOR); + mul_2_x->SetLoDLevel(0); + mul_2_x->SetDataType(proto::VarType::FP32); + mul_2_x->SetShape({1000, 784}); + + auto* mul_2_y = global_block->Var("Mul_2_Y"); + mul_2_y->SetType(proto::VarType::LOD_TENSOR); + mul_2_y->SetLoDLevel(0); + mul_2_y->SetDataType(proto::VarType::FP32); + mul_2_y->SetShape({784, 100}); + + auto* mul_op_2 = sub_blocks[0]->AppendOp(); + mul_op_2->SetType("mul"); + mul_op_2->SetInput("X", {mul_2_x->Name()}); + mul_op_2->SetInput("Y", {mul_2_y->Name()}); + + auto* mul_2_out = global_block->Var("Mul_2_Out"); + mul_2_out->SetType(proto::VarType::LOD_TENSOR); + mul_op_2->SetOutput("Y", {mul_2_out->Name()}); + + auto* less_than_op_2 = sub_blocks[0]->AppendOp(); + less_than_op_2->SetType("less_than"); + auto* less_than_2_x = global_block->Var("Less_than_2_X"); + less_than_2_x->SetType(proto::VarType::LOD_TENSOR); + less_than_2_x->SetLoDLevel(0); + less_than_2_x->SetDataType(proto::VarType::FP32); + less_than_2_x->SetShape({1}); + + auto* less_than_2_y = global_block->Var("Less_than_2_Y"); + less_than_2_y->SetType(proto::VarType::LOD_TENSOR); + less_than_2_y->SetLoDLevel(0); + less_than_2_y->SetDataType(proto::VarType::FP32); + less_than_2_y->SetShape({1}); + + less_than_op_2->SetInput("X", {less_than_2_x->Name()}); + less_than_op_2->SetInput("Y", {less_than_2_y->Name()}); + + auto* less_than_2_out = global_block->Var("Less_than_2_Out"); + less_than_2_out->SetType(proto::VarType::BOOL); + less_than_op_2->SetOutput("Out", {less_than_2_out->Name()}); + + auto* cond_op = sub_blocks[0]->AppendOp(); + cond_op->SetType("conditional_block"); + cond_op->SetAttr("sub_block", sub_blocks[1]); + + auto* cond_x = sub_blocks[0]->Var("Cond_X"); + cond_x->SetType(proto::VarType::LOD_TENSOR); + cond_x->SetLoDLevel(0); + cond_x->SetDataType(proto::VarType::FP32); + cond_x->SetShape({1}); + + cond_op->SetInput("kInputs", {cond_x->Name()}); + cond_op->SetInput("kCondition", {less_than_2_out->Name()}); + + auto* cond_out = sub_blocks[0]->Var("Cond_Out"); + cond_out->SetType(proto::VarType::LOD_TENSOR); + cond_out->SetLoDLevel(0); + cond_out->SetDataType(proto::VarType::FP32); + cond_out->SetShape({1}); + + auto* scope = sub_blocks[0]->Var("Scope"); + scope->SetType(proto::VarType::STEP_SCOPES); + + cond_op->SetOutput("kOutputs", {cond_out->Name()}); + cond_op->SetOutput("kScope", {scope->Name()}); + + auto* mul_3_x = global_block->Var("Mul_3_X"); + mul_3_x->SetType(proto::VarType::LOD_TENSOR); + mul_3_x->SetLoDLevel(0); + mul_3_x->SetDataType(proto::VarType::FP32); + mul_3_x->SetShape({1000, 784}); + + auto* mul_3_y = 
global_block->Var("Mul_3_Y"); + mul_3_y->SetType(proto::VarType::LOD_TENSOR); + mul_3_y->SetLoDLevel(0); + mul_3_y->SetDataType(proto::VarType::FP32); + mul_3_y->SetShape({784, 100}); + + auto* mul_3_out = global_block->Var("Mul_3_Out"); + mul_3_out->SetType(proto::VarType::LOD_TENSOR); + + auto* mul_op_3 = sub_blocks[1]->AppendOp(); + mul_op_3->SetType("mul"); + mul_op_3->SetInput("X", {mul_3_x->Name()}); + mul_op_3->SetInput("Y", {mul_3_y->Name()}); + mul_op_3->SetOutput("Y", {mul_3_out->Name()}); +} + +bool VarComparator(const VarDesc* a, const VarDesc* b) { + return a->Name() < b->Name(); +} + +void CheckBlockVarsEqual(const BlockDesc& before_block, + const BlockDesc& after_block) { + auto before_vars = before_block.AllVars(); + auto after_vars = after_block.AllVars(); + + EXPECT_EQ(before_vars.size(), after_vars.size()); + + // var's order is unimportant + std::sort(before_vars.begin(), before_vars.end(), VarComparator); + std::sort(after_vars.begin(), after_vars.end(), VarComparator); + + for (size_t var_idx = 0; var_idx < before_vars.size(); ++var_idx) { + const auto& before_var = before_vars.at(var_idx); + const auto& after_var = after_vars.at(var_idx); + + EXPECT_EQ(before_var->Name(), after_var->Name()); + EXPECT_EQ(before_var->GetType(), after_var->GetType()); + } +} + +void CheckOpInputsEqual(const OpDesc* before_op, const OpDesc* after_op) { + const auto& before_inputs = before_op->InputNames(); + const auto& after_inputs = after_op->InputNames(); + + EXPECT_EQ(before_inputs.size(), after_inputs.size()); + for (size_t in_idx = 0; in_idx < before_inputs.size(); ++in_idx) { + const auto& before_in_arg = before_inputs[in_idx]; + const auto& after_in_arg = after_inputs[in_idx]; + EXPECT_EQ(before_in_arg, after_in_arg); + + const auto& before_in_vars = before_op->Input(before_in_arg); + const auto& after_in_vars = after_op->Input(after_in_arg); + EXPECT_EQ(before_in_vars, after_in_vars); + } +} + +void CheckOpOutputsEqual(const OpDesc* before_op, const OpDesc* after_op) { + const auto& before_outputs = before_op->OutputNames(); + const auto& after_outputs = after_op->OutputNames(); + + EXPECT_EQ(before_outputs.size(), after_outputs.size()); + for (size_t out_idx = 0; out_idx < before_outputs.size(); ++out_idx) { + const auto& before_out_arg = before_outputs[out_idx]; + const auto& after_out_arg = after_outputs[out_idx]; + EXPECT_EQ(before_out_arg, after_out_arg); + + const auto& before_out_vars = before_op->Output(before_out_arg); + const auto& after_out_vars = after_op->Output(after_out_arg); + EXPECT_EQ(before_out_vars, after_out_vars); + } +} + +void CheckOpAttrsEqual(const OpDesc* before_op, const OpDesc* after_op) { + const auto& before_attrs = before_op->AttrNames(); + const auto& after_attrs = after_op->AttrNames(); + + EXPECT_EQ(before_attrs.size(), after_attrs.size()); + for (size_t attr_idx = 0; attr_idx < before_attrs.size(); ++attr_idx) { + const auto& before_attr = before_attrs[attr_idx]; + const auto& after_attr = after_attrs[attr_idx]; + EXPECT_EQ(before_attr, after_attr); + + EXPECT_EQ(before_op->GetAttrType(before_attr), + after_op->GetAttrType(after_attr)); + } +} + +void CheckBlockOpsEqual(const BlockDesc& before_block, + const BlockDesc& after_block) { + EXPECT_EQ(before_block.OpSize(), after_block.OpSize()); + + // op's order must be the same + for (size_t op_idx = 0; op_idx < before_block.OpSize(); ++op_idx) { + const auto& before_op = before_block.Op(op_idx); + const auto& after_op = after_block.Op(op_idx); + + EXPECT_EQ(before_op->Type(), after_op->Type()); 
+ + // Step4.2.1 : check each op's input + CheckOpInputsEqual(before_op, after_op); + + // Step4.2.2 : check each op's output + CheckOpOutputsEqual(before_op, after_op); + + // Step4.2.3 : check each op's attribute + CheckOpAttrsEqual(before_op, after_op); + } +} + +TEST(GraphToProgramPass, MultiBlock) { + // Set FLAGS_convert_all_blocks to true to make sure this test works. + bool flag_temp = FLAGS_convert_all_blocks; + FLAGS_convert_all_blocks = true; + + // Step1: Build a program with multi block + ProgramDesc before_prog; + BuildProgramWithMultiBlock(&before_prog); + + // Step2: Convert program into graph + std::unique_ptr g(new ir::Graph(before_prog)); + + // Step3 : Convert graph back to program + auto pass = paddle::framework::ir::PassRegistry::Instance().Get( + "graph_to_program_pass"); + + ProgramDesc after_prog; + pass->SetNotOwned("program", &after_prog); + pass->Apply(g.get()); + + // Step4 : Check tow program equal + EXPECT_EQ(before_prog.Size(), after_prog.Size()); + + for (size_t block_idx = 0; block_idx < before_prog.Size(); ++block_idx) { + const auto& before_block = before_prog.Block(block_idx); + const auto& after_block = after_prog.Block(block_idx); + + EXPECT_EQ(before_block.ID(), after_block.ID()); + + // Step4.1 : check each block's var + CheckBlockVarsEqual(before_block, after_block); + + // Step4.2 : check each block's op + CheckBlockOpsEqual(before_block, after_block); + } + + // Recover FLAGS_convert_all_blocks. + FLAGS_convert_all_blocks = flag_temp; +} + +void BuildProgramWithScaleLossGrad(Graph* g) { + OpDesc op1; + op1.SetType("op1"); + OpDesc op2; + op2.SetType("op2"); + OpDesc op3; + op3.SetType("op3"); + OpDesc op4; + op4.SetType("op4"); + VarDesc var1("var1"); + VarDesc var2("var2"); + + ir::Node* o1 = g->CreateOpNode(&op1); + ir::Node* o2 = g->CreateOpNode(&op2); + ir::Node* o3 = + g->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation); + ir::Node* o4 = + g->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation); + ir::Node* v1 = g->CreateVarNode(&var1); + ir::Node* v2 = g->CreateVarNode(&var2); + + // o1->v1->o2 + o1->outputs.push_back(v1); + o2->inputs.push_back(v1); + v1->inputs.push_back(o1); + v1->outputs.push_back(o2); + // o3->v1 + o3->outputs.push_back(v1); + v1->inputs.push_back(o1); + v1->inputs.push_back(o3); + // o4->v2 + o4->outputs.push_back(v2); + v2->inputs.push_back(o4); +} + +TEST(GraphToProgramPass, ReplaceScaleLossGrad) { + // Step1: Build a program with multi block + ProgramDesc before_prog; + Graph before_graph(before_prog); + BuildProgramWithScaleLossGrad(&before_graph); + + // Step2 : Convert graph back to program + auto pass = paddle::framework::ir::PassRegistry::Instance().Get( + "graph_to_program_pass"); + + ProgramDesc after_prog; + pass->SetNotOwned("program", &after_prog); + pass->Apply(&before_graph); + + // Step3 : statistics scale_loss_grad and fill_constant number + int scale_node_num = 0, fill_node_num = 0; + const auto& before_nodes_set = before_graph.Nodes(); + for (const auto& n : before_nodes_set) { + if (n->Name() == "scale_loss_grad") { + ++scale_node_num; + } else if (n->Name() == "fill_constant") { + ++fill_node_num; + } + } + + int scale_op_num = 0, fill_op_num = 0; + const auto& block = after_prog.Block(0); + for (const auto& op : block.AllOps()) { + if (op->Type() == "fill_constant") { + ++fill_op_num; + } else if (op->Type() == "scale_loss_grad") { + ++scale_op_num; + } + } + + // Check pass OK + EXPECT_EQ(scale_op_num, 0); + EXPECT_EQ(scale_node_num + fill_node_num, fill_op_num); +} + } 
// namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/while_op_eager_deletion_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/while_op_eager_deletion_pass.cc index 6755cf0b275..c9fdfafe4c4 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/while_op_eager_deletion_pass.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/while_op_eager_deletion_pass.cc @@ -26,6 +26,13 @@ using OpVariant = operators::OpVariant; class WhileOpEagerDeletionPass : public ir::Pass { protected: void ApplyImpl(ir::Graph *graph) const override { + if (!graph->IsMainGraph()) { + // TODO(zhhsplendid): the WhileOpEagerDeletionPass is based on old Graph, + // which only applies to the main block graph. The new Eager Deletion + // Technical can be added after we write new while_op based on SubGraph + // instead of SubBlock + return; + } auto all_ops = ir::FilterByNodeWrapper(*graph); // Find all while_op and while_grad_op. In case of @to_static, graph @@ -47,6 +54,7 @@ class WhileOpEagerDeletionPass : public ir::Pass { } } if (graph->IsConstructedByPartialProgram()) { + VLOG(4) << "Is Paritial Program"; PADDLE_ENFORCE_LE( target_ops.size(), 1, platform::errors::InvalidArgument( @@ -69,8 +77,11 @@ class WhileOpEagerDeletionPass : public ir::Pass { } for (auto &ops_pair : target_ops) { + VLOG(4) << "Scope Idx = " << ops_pair.first; auto &while_ops = ops_pair.second.first; + VLOG(4) << "while_ops.size() = " << while_ops.size(); auto &while_grad_ops = ops_pair.second.second; + VLOG(4) << "while_grad_ops.size() = " << while_grad_ops.size(); operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp( graph->OriginProgram(), while_ops, while_grad_ops); } diff --git a/paddle/fluid/framework/ir/node.cc b/paddle/fluid/framework/ir/node.cc index 7143c9a7a3e..51542fc8085 100644 --- a/paddle/fluid/framework/ir/node.cc +++ b/paddle/fluid/framework/ir/node.cc @@ -30,7 +30,7 @@ std::unique_ptr CreateNodeForTest(const std::string &name, } std::unique_ptr CreateNodeForTest(VarDesc *var_desc) { - return std::unique_ptr(new Node(var_desc)); + return std::unique_ptr(new Node(var_desc, 0)); } std::unique_ptr CreateNodeForTest(OpDesc *op_desc) { diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index d0db3bd36e1..d0568f39ef6 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -136,9 +136,98 @@ class Node { var_desc_->SetName(new_name); } + int DescOrder() const { return desc_order_; } + + int GetVarNodeBlockId() const { + PADDLE_ENFORCE_EQ( + type_ == Type::kVariable && var_desc_, true, + platform::errors::InvalidArgument("Node must be type of variable.")); + return block_id_; + } + + const std::string ToString() const { + if (IsOp()) { + std::string op_str(Name()); + + const auto& op = Op(); + if (op == nullptr) { + // Node is an Op but hasn't OpDesc (often create by CreateEmptyNode), + // like ScaleLossGradOp, it's type is OpHandle, which created by Pass + // and then inserted into graph. + // For OpHandle, we have to use Node's input and output for sorting. 
+ std::vector sorted_inputs(inputs); + std::vector sorted_outputs(outputs); + + auto comparator = [](Node* a, Node* b) { + return a->Name() > b->Name(); + }; + std::stable_sort(sorted_inputs.begin(), sorted_inputs.end(), + comparator); + std::stable_sort(sorted_outputs.begin(), sorted_outputs.end(), + comparator); + + std::string out_str = "{"; + std::string pre_str = ""; + for (const auto& output : sorted_outputs) { + out_str.append(pre_str + output->Name()); + pre_str = ", "; + } + out_str.append("} = "); + + std::string in_str = "("; + pre_str = ""; + for (const auto& input : sorted_inputs) { + in_str.append(pre_str + input->Name()); + pre_str = ", "; + } + in_str.append(")"); + op_str = out_str + op_str + in_str; + } else { + // A normal Op, has OpDesc, create from ProgramDesc + std::string out_str = "{"; + std::string outer_pre_str = ""; + for (const auto& output : op->OutputNames()) { + out_str.append(outer_pre_str + output + "=["); + std::string inner_pre_str = ""; + for (const auto& arg : op->Output(output)) { + out_str.append(inner_pre_str + arg); + inner_pre_str = " ,"; + } + outer_pre_str = ", "; + out_str.append("]"); + } + out_str.append("} = "); + + std::string in_str = "("; + outer_pre_str = ""; + for (const auto& input : op->InputNames()) { + in_str.append(outer_pre_str + input + "=["); + std::string inner_pre_str = ""; + for (const auto& arg : op->Input(input)) { + in_str.append(inner_pre_str + arg); + inner_pre_str = " ,"; + } + outer_pre_str = " ,"; + in_str.append("]"); + } + in_str.append(")"); + op_str = out_str + op_str + in_str; + } + + return op_str; + } + return Name(); + } + std::vector inputs; std::vector outputs; + // Because NO_DESC_ORDER is a constexpr number, + // no one can change it, meanwhile, we need + // check whether the DescOrder invalid sometime, + // so expose it is a good idea + static constexpr int NO_DESC_ORDER = INT_MAX; + protected: std::string name_; std::unique_ptr var_desc_; @@ -146,30 +235,45 @@ class Node { Type type_; int id_; + int desc_order_; + int block_id_{-1}; + private: // ID can only set by a Graph. void SetId(int id) { id_ = id; } + // desc_order can only set by a Graph when constructing a Graph from a + // BlockDesc. 
+ void SetDescOrder(int desc_order) { desc_order_ = desc_order; } + friend class Graph; friend std::unique_ptr CreateNodeForTest(const std::string& name, Node::Type type); friend std::unique_ptr CreateNodeForTest(VarDesc* var_desc); friend std::unique_ptr CreateNodeForTest(OpDesc* op_desc); - explicit Node(const std::string& name, Type type) - : name_(name), var_desc_(nullptr), op_desc_(nullptr), type_(type) {} + explicit Node(const std::string& name, Type type, int block_id = 0) + : name_(name), + var_desc_(nullptr), + op_desc_(nullptr), + type_(type), + desc_order_(NO_DESC_ORDER), + block_id_(block_id) {} - explicit Node(VarDesc* var_desc) + explicit Node(VarDesc* var_desc, int block_id) : name_(var_desc->Name()), var_desc_(new VarDesc(*var_desc)), op_desc_(nullptr), - type_(Type::kVariable) {} + type_(Type::kVariable), + desc_order_(NO_DESC_ORDER), + block_id_(block_id) {} explicit Node(OpDesc* op_desc) : name_(op_desc->Type()), var_desc_(nullptr), op_desc_(new OpDesc(*op_desc, op_desc->Block())), - type_(Type::kOperation) {} + type_(Type::kOperation), + desc_order_(NO_DESC_ORDER) {} Node() = delete; diff --git a/paddle/fluid/framework/ir/node_test.cc b/paddle/fluid/framework/ir/node_test.cc index 73f5b6619c1..9c47df402bd 100644 --- a/paddle/fluid/framework/ir/node_test.cc +++ b/paddle/fluid/framework/ir/node_test.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/node.h" #include "gtest/gtest.h" +#include "paddle/fluid/framework/var_desc.h" namespace paddle { namespace framework { @@ -75,6 +76,32 @@ TEST(NodeTest, Basic) { EXPECT_FALSE(alive2); } +TEST(NodeTest, ToString) { + VarDesc var_desc("n2"); + OpDesc op_desc; + op_desc.SetType("test_op"); + op_desc.SetInput("X", {"x1", "x2", "x3"}); + op_desc.SetOutput("Y", {"y1", "y2"}); + + std::unique_ptr n1(CreateNodeForTest("n1", Node::Type::kVariable)); + std::unique_ptr n2(CreateNodeForTest(&var_desc)); + std::unique_ptr n3(CreateNodeForTest("n3", Node::Type::kOperation)); + std::unique_ptr n4(CreateNodeForTest(&op_desc)); + + EXPECT_EQ(n1->ToString(), "n1"); + EXPECT_EQ(n2->ToString(), "n2"); + + EXPECT_EQ(n3->Op(), nullptr); + EXPECT_EQ(n3->ToString(), "{} = n3()"); + EXPECT_NE(n4->Op(), nullptr); + EXPECT_EQ(n4->ToString(), "{Y=[y1 ,y2]} = test_op(X=[x1 ,x2 ,x3])"); + + n3->inputs.push_back(n1.get()); + n3->outputs.push_back(n2.get()); + EXPECT_EQ(n3->Op(), nullptr); + EXPECT_EQ(n3->ToString(), "{n2} = n3(n1)"); +} + } // namespace ir } // namespace framework } // namespace paddle -- GitLab
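
A minimal, self-contained sketch of the idea behind the TopologySortGraphByDescOrder function introduced in this patch: Kahn's algorithm driven by a min-heap keyed on each op's position in the original BlockDesc, with a name-based tie-breaker, so graph_to_program emits ops in a deterministic order that matches the source program. This is not Paddle's actual API — OpNode, SortByDescOrder, direct op-to-op edges, and main() are hypothetical stand-ins (the real graph routes dependencies through variable nodes and breaks ties with Node::ToString()).

#include <cassert>
#include <iostream>
#include <queue>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

// Simplified op node: a name (stand-in for Node::ToString()) and the op's
// original position in its BlockDesc (stand-in for Node::DescOrder()).
struct OpNode {
  std::string name;
  int desc_order;
};

// Mirrors DescOrderComparator: with std::priority_queue this pops the op
// with the smallest desc_order first, falling back to the name on ties.
struct DescOrderGreater {
  bool operator()(const OpNode* a, const OpNode* b) const {
    return a->desc_order > b->desc_order ||
           (a->desc_order == b->desc_order && a->name > b->name);
  }
};

// Kahn's algorithm over op->op edges ((from, to) index pairs into `ops`),
// always preferring the ready op that appeared earliest in the program.
std::vector<const OpNode*> SortByDescOrder(
    const std::vector<OpNode>& ops,
    const std::vector<std::pair<int, int>>& edges) {
  std::unordered_map<const OpNode*, std::unordered_set<const OpNode*>> in, out;
  for (const auto& op : ops) {
    in[&op];
    out[&op];
  }
  for (const auto& e : edges) {
    in[&ops[e.second]].insert(&ops[e.first]);
    out[&ops[e.first]].insert(&ops[e.second]);
  }

  std::priority_queue<const OpNode*, std::vector<const OpNode*>,
                      DescOrderGreater>
      ready;
  for (const auto& op : ops) {
    if (in[&op].empty()) ready.push(&op);
  }

  std::vector<const OpNode*> sorted;
  while (!ready.empty()) {
    const OpNode* cur = ready.top();  // copy, not reference: it pops next
    ready.pop();
    sorted.push_back(cur);
    for (const OpNode* next : out[cur]) {
      in[next].erase(cur);
      if (in[next].empty()) ready.push(next);  // in-degree dropped to zero
    }
  }
  assert(sorted.size() == ops.size() && "graph must be acyclic");
  return sorted;
}

int main() {
  // Both "c" and "b" become ready once "a" is emitted; "c" has the smaller
  // desc_order, so the result is a c b regardless of container/hash order.
  std::vector<OpNode> ops = {{"a", 0}, {"c", 1}, {"b", 2}};
  std::vector<std::pair<int, int>> edges = {{0, 1}, {0, 2}};
  for (const OpNode* op : SortByDescOrder(ops, edges)) {
    std::cout << op->name << " ";
  }
  std::cout << "\n";  // prints: a c b
}

Determinism is the point of the desc-order comparator: without it, ties between ready ops would be resolved by node-pointer or hash order, and round-trip checks such as the GraphToProgramPass.MultiBlock test added in this patch could see op orders that differ from the original block.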