/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/ir/graph_helper.h"

#include <algorithm>
#include <deque>
#include <fstream>
#include <map>
#include <memory>
#include <queue>
#include <set>
#include <sstream>
#include <stack>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/program_utils.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/framework/details/nccl_op_handle.h"
#include "paddle/fluid/platform/collective_helper.h"
#endif

DECLARE_bool(convert_all_blocks);
PADDLE_DEFINE_EXPORTED_string(print_sub_graph_dir,
                              "",
                              "FLAGS_print_sub_graph_dir is used "
                              "to print the nodes of sub_graphs.");

namespace paddle {
namespace framework {
namespace ir {
namespace {

// Post-order DFS over `adj_list` starting at `node`: every not-yet-visited
// adjacent node is emitted into `ret` before `node` itself.
template <class NodeComparator = ir::NodeComp>
void SortHelper(const std::map<ir::Node *,
                               std::set<ir::Node *, NodeComparator>,
                               NodeComparator> &adj_list,
                ir::Node *node,
                std::unordered_set<ir::Node *> *visited,
                std::vector<ir::Node *> *ret) {
  visited->insert(node);

  for (auto adj : adj_list.at(node)) {
    if (visited->find(adj) == visited->end()) {
      SortHelper<NodeComparator>(adj_list, adj, visited, ret);
    }
  }

  VLOG(5) << "topology sort insert: " << node->Name() << " "
          << reinterpret_cast<void *>(node) << " input " << node->inputs.size();
  ret->push_back(node);
}

// DFS cycle detection. `in_trace` holds the nodes on the current DFS path;
// reaching a node that is already in `in_trace` means a cycle exists.
// When `circles` is non-null, the nodes collected for the detected cycle are
// appended to it before returning true.
template <class NodeComparator = ir::NodeComp>
bool HasCircleHelper(
    ir::Node *node,
    const std::map<ir::Node *,
                   std::set<ir::Node *, NodeComparator>,
                   NodeComparator> &adj_list,
    std::unordered_set<ir::Node *> *visited,
    std::unordered_set<ir::Node *> *in_trace,
    std::vector<std::vector<ir::Node *>> *circles) {
  if (visited->find(node) == visited->end()) {
    visited->insert(node);
    in_trace->insert(node);

    for (ir::Node *in : adj_list.at(node)) {
      if (visited->find(in) == visited->end() &&
          HasCircleHelper<NodeComparator>(
              in, adj_list, visited, in_trace, circles)) {
        return true;
      } else if (in_trace->find(in) != in_trace->end()) {
        if (circles != nullptr) {
          std::vector<ir::Node *> circle;
          circle.emplace_back(in);
          ir::Node *p = in;
          // NOTE: the range below is fixed when the loop starts, so
          // re-assigning `p` does not change which adjacency set is walked.
          for (auto &adj : adj_list.at(p)) {
            if (in_trace->count(adj)) {
              circle.emplace_back(adj);
              p = adj;
            }
          }
          circles->emplace_back(circle);
        }
        return true;
      }
    }
  }
  in_trace->erase(node);
  return false;
}

// Runs HasCircleHelper from every key of `adj_list`; returns true as soon as
// one cycle is found.
template <class NodeComparator = ir::NodeComp>
bool HasCircleInternal(
    const std::map<ir::Node *,
                   std::set<ir::Node *, NodeComparator>,
                   NodeComparator> &adj_list,
    std::vector<std::vector<ir::Node *>> *circles) {
  std::unordered_set<ir::Node *> visited;
  std::unordered_set<ir::Node *> in_trace;
  for (auto &adj : adj_list) {
    if (HasCircleHelper<NodeComparator>(
            adj.first, adj_list, &visited, &in_trace, circles)) {
      return true;
    }
  }
  return false;
}

}  // namespace

// Returns true if the operation adjacency list of `graph` contains a cycle.
bool HasCircle(const Graph &graph) {
  return HasCircleInternal(BuildOperationAdjList(graph), nullptr);
}

// For every variable name that has at least one persistable node, checks that
// all nodes sharing that name carry an equal VarDesc.
bool VarDescIsConsistency(const Graph &graph) {
  std::unordered_map<std::string, std::unordered_set<ir::Node *>>
      var_name2node_set;
  for (ir::Node *node : graph.Nodes()) {
    if (node->IsVar() && node->Var()) {
      var_name2node_set[node->Var()->Name()].emplace(node);
    }
  }
  for (auto &iter : var_name2node_set) {
    auto &first_node = *iter.second.begin();
    bool is_persistable = std::any_of(
        iter.second.begin(), iter.second.end(), [](const ir::Node *node) {
          return node->Var()->Persistable();
        });
    if (is_persistable) {
      bool is_consistency =
          std::all_of(iter.second.begin(),
                      iter.second.end(),
                      [&first_node](const ir::Node *node) {
                        return *node->Var() == *first_node->Var();
                      });
      if (!is_consistency) return false;
    }
  }
  return true;
}

// Same as HasCircle, but additionally records the detected cycle in `circles`.
bool FindCircleSubGraph(const Graph &graph,
                        std::vector<std::vector<ir::Node *>> *circles) {
  return HasCircleInternal(BuildOperationAdjList(graph), circles);
}

// Topologically sorts the operation nodes of `graph`.
// Raises InvalidArgument if the operation graph contains a cycle.
std::vector<ir::Node *> TopologySortOperations(const Graph &graph) {
  std::map<ir::Node *, std::set<ir::Node *, ir::NodeComp>, ir::NodeComp>
      adj_list = BuildOperationAdjList(graph);
  PADDLE_ENFORCE_EQ(HasCircleInternal(adj_list, nullptr),
                    false,
                    platform::errors::InvalidArgument(
                        "Generated graph shouldn't contain cycle."));
  std::unordered_set<ir::Node *> visited;
  std::vector<ir::Node *> ret;
  // Iterate by reference: copying each entry would copy its adjacency set.
  for (const auto &adj : adj_list) {
    if (visited.find(adj.first) == visited.end()) {
      SortHelper<ir::NodeComp>(adj_list, adj.first, &visited, &ret);
    }
  }

  return ret;
}

// Returns true when, in the topological order, every op consumes at least one
// output of the op placed immediately before it.
bool IsTopologySortOperationsUnique(const Graph &graph) {
  auto nodes = TopologySortOperations(graph);
  size_t n = nodes.size();
  for (size_t i = 1; i < n; ++i) {
    auto *prev_op = nodes[i - 1];
    auto *cur_op = nodes[i];

    std::unordered_set<Node *> prev_op_outputs;
    for (auto *output : prev_op->outputs) {
      prev_op_outputs.insert(output);
    }

    bool found = false;
    for (auto *input : cur_op->inputs) {
      if (prev_op_outputs.count(input) > 0) {
        found = true;
        break;
      }
    }
    if (!found) {
      return false;
    }
  }
  return true;
}

// Build operator outlink edge table.
std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationOutAdjList(
    const Graph &graph) {
  std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list;

  for (auto &n : graph.Nodes()) {
    if (!n->IsOp()) continue;
    if (adj_list.find(n) == adj_list.end()) {
      adj_list[n] = std::unordered_set<ir::Node *>();
    }
    for (auto &var : n->outputs) {
      for (auto &adj_n : var->outputs) {
        PADDLE_ENFORCE_EQ(adj_n->NodeType(),
                          ir::Node::Type::kOperation,
                          platform::errors::InvalidArgument(
                              "Node(%s)'s type(%d) must be kOperation type.",
                              adj_n->Name(),
                              static_cast<int>(adj_n->NodeType())));
        VLOG(40) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
                 << " -> " << n->Name() << reinterpret_cast<void *>(n)
                 << "  via " << var->Name() << reinterpret_cast<void *>(var);
        adj_list[n].insert(adj_n);
      }
    }
  }
  return adj_list;
}

// Returns the ops of `graph` in DFS order, starting from ops that have no
// inputs or whose inputs are all persistable variables. Each op appears at
// most once in the result.
std::vector<ir::Node *> OpDFSSort(const Graph &graph) {
  auto edge_table = BuildOperationOutAdjList(graph);
  std::stack<Node *> stack;
  for (auto &ele : edge_table) {
    if (ele.first->inputs.empty()) {
      // find the input ops (those without input vars)
      stack.push(ele.first);
    } else {
      // find the ops with only persistable vars as inputs.
      bool all_persistable = true;
      for (auto *input : ele.first->inputs) {
        if (!(input->IsVar() && input->Var() && input->Var()->Persistable())) {
          all_persistable = false;
        }
      }
      if (all_persistable) {
        stack.push(ele.first);
      }
    }
  }

  std::vector<Node *> res;
  // start from the feed op and DFS
  std::unordered_set<Node *> unique_set;
  while (!stack.empty()) {
    // will start from the last feed by default.
    auto cur = stack.top();
    stack.pop();
    // An op may have been pushed several times before its first pop (e.g.
    // A->B, A->C, C->B); skip it if it has already been emitted.
    if (!unique_set.insert(cur).second) continue;
    res.push_back(cur);

    for (auto *op : edge_table[cur]) {
      if (!unique_set.count(op)) {
        stack.push(op);
      }
    }
  }
  return res;
}

// Topological sort that visits ops in the order produced by OpDFSSort and
// emits an op once all of its inputs are ready.
// Raises InvalidArgument if some op can never become ready (the previous
// implementation would loop forever in that situation).
std::vector<ir::Node *> TopologyDfsSortOperations(const Graph &graph) {
  std::vector<ir::Node *> nodes;
  std::unordered_map<Node *, int> in_degree;

  auto set_out_ops_ready = [&](Node *var) {
    for (auto *op : var->outputs) {
      --in_degree[op];
    }
  };
  // build in_degree
  for (auto *node : graph.Nodes()) {
    if (node->IsOp()) {
      in_degree[node] += node->inputs.size();
    } else if (node->IsVar() && node->inputs.empty()) {
      // put all the inputs of the whole graph ready.
      set_out_ops_ready(node);
    }
  }

  std::deque<Node *> op_queue;
  // first visit
  for (auto &node : OpDFSSort(graph)) {
    if (node->IsOp()) {
      op_queue.push_back(node);
    }
  }

  // traverse the graph
  size_t num_ops = op_queue.size();
  while (num_ops) {
    bool progressed = false;
    for (auto it = op_queue.begin(); it != op_queue.end(); ++it) {
      auto *&cur_op = *it;
      if (!cur_op || in_degree[cur_op] > 0) continue;
      // visit this node
      // put all the output var of this op valid.
      for (auto *out_var : cur_op->outputs) {
        if (!out_var) continue;
        set_out_ops_ready(out_var);
      }
      VLOG(8) << "visit " << cur_op->Name();
      nodes.push_back(cur_op);

      cur_op = nullptr;
      num_ops--;
      progressed = true;
    }
    // A pass that schedules nothing leaves the state unchanged, so every
    // following pass would do the same: fail loudly instead of spinning.
    PADDLE_ENFORCE_EQ(
        progressed,
        true,
        platform::errors::InvalidArgument(
            "Cannot finish topology sort: %d operator(s) never become ready, "
            "the graph may contain a cycle or unreachable producers.",
            static_cast<int>(num_ops)));
  }

  return nodes;
}

// Counts the connected components of `graph` (edges are followed in both
// directions). When FLAGS_print_sub_graph_dir is set and more than one
// component exists, a description of the components is written to that path.
size_t GraphNum(const Graph &graph) {
  std::unordered_set<ir::Node *> nodes(graph.Nodes());
  std::unordered_set<ir::Node *> visited_nodes;
  visited_nodes.reserve(nodes.size());
  std::deque<ir::Node *> q_nodes;
  std::vector<std::unordered_set<ir::Node *>> graph_nodes;
  std::unordered_set<ir::Node *> g_nodes;
  // q_set used to record records in the queue.
  std::unordered_set<ir::Node *> q_set;
  size_t graph_count = 0;

  auto traverse_nodes = [&visited_nodes, &q_nodes, &q_set](
                            const std::vector<ir::Node *> &neighbors) {
    for (auto n : neighbors) {
      if (visited_nodes.count(n) == 0 && q_set.count(n) == 0) {
        q_nodes.push_back(n);
        q_set.insert(n);
      }
    }
  };

  while (visited_nodes.size() != nodes.size()) {
    if (!q_nodes.empty()) {
      auto cur_node = q_nodes.front();
      q_nodes.pop_front();
      q_set.erase(cur_node);
      visited_nodes.insert(cur_node);
      g_nodes.insert(cur_node);
      traverse_nodes(cur_node->inputs);
      traverse_nodes(cur_node->outputs);
    } else {
      // The queue drained: the previous component is complete, seed the next
      // one with any node that has not been visited yet.
      ++graph_count;
      if (g_nodes.size()) {
        graph_nodes.emplace_back(g_nodes);
      }
      g_nodes.clear();
      for (auto &n : nodes) {
        if (visited_nodes.count(n) == 0) {
          q_nodes.push_back(n);
          q_set.insert(n);
          break;
        }
      }
    }
  }

  if (g_nodes.size()) {
    graph_nodes.emplace_back(g_nodes);
  }

  if (FLAGS_print_sub_graph_dir.size()) {
    if (graph_nodes.size() > 1) {
      std::stringstream out;
      for (auto &g_n : graph_nodes) {
        out << "graph_nodes: " << g_n.size() << "\n";
      }
      out << "\n\n";
      for (auto &g_n : graph_nodes) {
        out << "graph_nodes: " << g_n.size();
        for (auto &node : g_n) {
          out << "\nNode: " << node->Name() << " in [";
          for (auto &n : node->inputs) {
            out << n->Name() << ", ";
          }
          out << "], out[";
          for (auto &n : node->outputs) {
            out << n->Name() << ", ";
          }
          out << "]";
        }
        out << "\n\n\n";
      }
      std::ofstream fout(FLAGS_print_sub_graph_dir);
      PADDLE_ENFORCE_EQ(fout.good(),
                        true,
                        platform::errors::Unavailable(
                            "Can not open file %s for printing the graph.",
                            FLAGS_print_sub_graph_dir));
      fout << out.str();
    }
  }

  return graph_count;
}

// Removes every node that has neither inputs nor outputs.
void CleanIndividualNodes(Graph *graph) {
  // Collect first: nodes must not be removed while iterating graph->Nodes().
  std::unordered_set<Node *> nodes2rm;
  for (auto *node : graph->Nodes()) {
    if (node->inputs.empty() && node->outputs.empty()) {
      nodes2rm.insert(node);
    }
  }

  for (auto *node : nodes2rm) {
    graph->RemoveNode(node);
  }
}

// Dispatches to the sort implementation selected by `sort_kind`
// (SortKind::TS -> TopologySortOperations, otherwise the DFS variant).
std::vector<Node *> TopologyVarientSort(const Graph &graph,
                                        SortKind sort_kind) {
  switch (sort_kind) {
    case SortKind::TS:
      return framework::ir::TopologySortOperations(graph);
    default:
      return framework::ir::TopologyDfsSortOperations(graph);
  }
}

// Orders nodes by DescOrder(), then id(), then ToString().
class DescOrderComparator {
 public:
  bool operator()(Node *const &n1, Node *const &n2) const {
    if (n1->DescOrder() < n2->DescOrder()) {
      return true;
    } else if (n1->DescOrder() == n2->DescOrder()) {
      return n1->id() < n2->id() ||
             (n1->id() == n2->id() && n1->ToString() < n2->ToString());
    }
    return false;
  }
};

// Topologically sorts the operation nodes of `graph`, using
// DescOrderComparator to order the adjacency list.
// Raises InvalidArgument if the operation graph contains a cycle.
std::vector<ir::Node *> TopologySortGraphByDescOrder(const Graph &graph) {
  std::map<ir::Node *,
           std::set<ir::Node *, DescOrderComparator>,
           DescOrderComparator>
      adj_list = BuildOperationAdjList<DescOrderComparator>(graph);
  PADDLE_ENFORCE_EQ(HasCircleInternal<DescOrderComparator>(adj_list, nullptr),
                    false,
                    platform::errors::InvalidArgument(
                        "Generated graph shouldn't contain cycle."));
  std::unordered_set<ir::Node *> visited;
  std::vector<ir::Node *> ret;
  // Iterate by reference: copying each entry would copy its adjacency set.
  for (const auto &adj : adj_list) {
    if (visited.find(adj.first) == visited.end()) {
      SortHelper<DescOrderComparator>(adj_list, adj.first, &visited, &ret);
    }
  }

  return ret;
}

// Erases every input/output argument of `op_desc` whose name contains
// ir::Node::kControlDepVarName, then flushes the desc.
void RemoveControlDepInputAndOuput(OpDesc *op_desc) {
  auto remove_control_dep_var = [](VariableNameMap *var_name_map) {
    for (auto &pair : *var_name_map) {
      std::vector<std::string> &var_names = pair.second;
      auto it = var_names.begin();
      while (it != var_names.end()) {
        if (it->find(ir::Node::kControlDepVarName) != std::string::npos) {
          // Log before erasing: after erase() `it` refers to the next
          // element, or to end() when the last name was removed.
          VLOG(6) << "Remove var " << *it;
          it = var_names.erase(it);
        } else {
          ++it;
        }
      }
    }
  };

  remove_control_dep_var(op_desc->MutableInputs());
  remove_control_dep_var(op_desc->MutableOutputs());
  op_desc->Flush();
}

// Fills `desc` as a fill_constant op that stands in for a scale_loss_grad
// node. dtype/value are taken from the wrapping ScaleLossGradOpHandle when the
// node is wrapped by an op handle. Returns `desc`.
static OpDesc *ReplaceScaleLossGradOp(const Node &node, OpDesc *desc) {
  desc->SetType("fill_constant");
  desc->SetAttr("shape", std::vector<int64_t>({1}));
  desc->SetAttr("value", 1.0f);

  if (node.IsWrappedBy<details::OpHandleBase>()) {
    details::OpHandleBase &op_hander =
        const_cast<Node *>(&node)->Wrapper<details::OpHandleBase>();
    auto *scale_loss_handle =
        dynamic_cast<details::ScaleLossGradOpHandle *>(&op_hander);
    PADDLE_ENFORCE_NOT_NULL(
        scale_loss_handle,
        platform::errors::InvalidArgument(
            "The op handle of node(%s) must be a ScaleLossGradOpHandle.",
            node.Name()));
    desc->SetAttr("dtype", scale_loss_handle->DType());
    desc->SetAttr("value", scale_loss_handle->Coeff());
  }

  desc->SetAttr("force_cpu", false);
  desc->SetAttr(
      OpProtoAndCheckerMaker::OpRoleAttrName(),
      (static_cast<int>(OpRole::kBackward) | static_cast<int>(OpRole::kLoss)));
  // TODO(Ruibiao) : Set OpDeviceAttrName when needed

  std::vector<std::string> output_names;
  for (auto out : node.outputs) {
    output_names.emplace_back(out->Name());
  }
  desc->SetOutput("Out", output_names);
  return desc;
}

// Appends to `ops` the op descs that replace an allreduce / fused_all_reduce
// node: an optional check_memory_continue (fused only), a c_allreduce_sum and,
// for the fused case, one depend op per input. Only implemented when Paddle is
// compiled with NCCL/RCCL; otherwise raises Unimplemented.
void ReplaceAllReduceOp(const Node &node,
                        proto::BlockDesc *block,
                        std::vector<OpDesc> *ops) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  bool is_fused = (node.Name() == "fused_all_reduce");
  details::OpHandleBase &op_handle =
      const_cast<Node *>(&node)->Wrapper<details::OpHandleBase>();

  auto *nccl_handle = dynamic_cast<details::NCCLOpHandleBase *>(&op_handle);
  PADDLE_ENFORCE_NOT_NULL(
      nccl_handle,
      platform::errors::InvalidArgument(
          "The op handle of node(%s) must be a NCCLOpHandleBase.",
          node.Name()));

  auto &in_var_handles = op_handle.Inputs();

  // Even if PADDLE_WITH_NCCL is defined, if the program runs on CPU,
  // nccl_ctxs_ in NCCLOpHandleBase will be nullptr, and calling the
  // GetComm() method will report an error.
  // There is bugs in all_reduce_op_handle method on CPU devices, skip
  // this case in temporary.
  if (nccl_handle->GetNcclContext() == nullptr) {
    VLOG(4) << "Skip replacing allreduce op because nccl_ctxs_ is nullptr.";
    return;
  }

  std::string all_reduce_var_name;
  // If fused, add check_memory_continue OP to fuse inputs
  if (is_fused) {
    all_reduce_var_name = "fake_coalesce_" + std::to_string(ops->size());
    proto::VarDesc var_desc;
    var_desc.set_name(all_reduce_var_name);
    var_desc.mutable_type()->set_type(proto::VarType::LOD_TENSOR);
    block->mutable_vars()->Add()->CopyFrom(var_desc);
    VLOG(4) << "add variable for check_memory_continue: "
            << all_reduce_var_name;

    // get inputs of check_memory_continue
    std::vector<std::string> in_names;
    for (const auto &in : in_var_handles) {
      if (dynamic_cast<details::DummyVarHandle *>(in) != nullptr) {
        continue;
      }
      in_names.emplace_back(in->Name());
    }

    ops->emplace_back();
    OpDesc &fuse_op_desc = ops->back();
    fuse_op_desc.SetType("check_memory_continue");
    fuse_op_desc.SetInput("X", in_names);
    fuse_op_desc.SetOutput("Out", {all_reduce_var_name});
    fuse_op_desc.SetOutput("XOut", in_names);
    fuse_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
                         (static_cast<int>(OpRole::kBackward)));
  } else {
    PADDLE_ENFORCE_EQ(in_var_handles.empty(),
                      false,
                      platform::errors::InvalidArgument(
                          "Node(%s) must have at least one input variable.",
                          node.Name()));
    all_reduce_var_name = in_var_handles[0]->Name();
  }

  // add c_allreduce_sum OP
  ops->emplace_back();
  OpDesc &all_reduce_op_desc = ops->back();
  all_reduce_op_desc.SetType("c_allreduce_sum");
  all_reduce_op_desc.SetInput("X", {all_reduce_var_name});
  all_reduce_op_desc.SetOutput("Out", {all_reduce_var_name});

  int ring_id =
      platform::NCCLCommContext::Instance().GetRingId(nccl_handle->GetComm());
  all_reduce_op_desc.SetAttr("ring_id", ring_id);
  all_reduce_op_desc.SetAttr("use_calc_stream", false);
  all_reduce_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
                             (static_cast<int>(OpRole::kBackward)));

  // handle grad merge
  if (auto *fused_grad_merge_handle =
          dynamic_cast<details::FusedGradMergeAllReduceOpHandle *>(
              &op_handle)) {
    VLOG(4) << "FusedGradMergeAllReduceOpHandle: add cond to c_allreduce_sum";
    const std::string cond_name = fused_grad_merge_handle->GradMergeCondName();
    all_reduce_op_desc.SetInput("Cond", {cond_name});
  } else if (auto *grad_merge_handle =
                 dynamic_cast<details::GradMergeAllReduceOpHandle *>(
                     &op_handle)) {
    VLOG(4) << "GradMergeAllReduceOpHandle: add cond to c_allreduce_sum";
    const std::string cond_name = grad_merge_handle->GradMergeCondName();
    all_reduce_op_desc.SetInput("Cond", {cond_name});
  }

  // Add dependency for FusedAllReduce.
  // For the following example:
  // ### fused_grad = FusedAllReduce(grad0, grad1, grad2, ...)
  // ### v0 = op0(grad0)
  // ### v1 = op1(grad1)
  // It is converted to:
  // ### fused_grad = check_memory_continue(grad0, grad1, grad2, ...)
  // ### fused_grad = c_sum_allreduce(fused_grad)
  // ### v0 = op0(grad0)
  // ### v1 = op1(grad1)
  // We should add the following dependency to ensure that op0 and op1 both run
  // after c_sum_allreduce:
  // ### grad0 = depend(grad0, fused_grad)
  // ### grad1 = depend(grad1, fused_grad)
  if (is_fused) {
    for (const auto &in : in_var_handles) {
      if (dynamic_cast<details::DummyVarHandle *>(in) != nullptr) {
        continue;
      }
      // NOTE: emplace_back may reallocate `ops`; earlier OpDesc references
      // must not be used past this point.
      ops->emplace_back();
      OpDesc &depend_op_desc = ops->back();
      depend_op_desc.SetType("depend");
      depend_op_desc.SetInput("X", {in->Name()});
      depend_op_desc.SetInput("Dep", {all_reduce_var_name});
      depend_op_desc.SetOutput("Out", {in->Name()});
    }
  }
#else
  PADDLE_THROW(
      platform::errors::Unimplemented("ReplaceAllReduceOp is only implemented "
                                      "for paddle compiled with NCCL/RCCL."));
#endif
}

// For a control-flow op living in a sub block (graph_idx != 0), copies the
// "skip_eager_deletion_vars" attribute from the matching op of the graph's
// origin program (same type, same input and output argument names) onto
// `node`'s OpDesc.
void UpdateControlOpSkipEagerDeletionVars(const Node &node,
                                          const Graph &graph,
                                          const size_t graph_idx,
                                          const std::string &control_type) {
  // Note(zhangbo): SkipEagerDeletionVars pass policy for control flow class op:
  // 1) if op is in main_block: SkipEagerDeletionVars information will be
  // written into Graph OpNode which wrapped by OpHandleBase; 2) if op is in
  // sub_block: SkipEagerDeletionVars information will be written into graph's
  // OriginProgram OpDesc. Please refer to
  // FindAllConditionalBlockAndConditionalBlockGradOp in
  // "paddle/fluid/operators/controlflow/conditional_block_op_helper.cc"
  if (graph_idx != 0) {
    // Bind by reference: the program is only read here, no copy is needed.
    const auto &origin_program = graph.OriginProgram();
    auto &block = origin_program.Block(graph_idx);
    for (size_t j = 0; j < block.OpSize(); ++j) {
      auto *op = block.Op(j);
      if (op->Type() == control_type &&
          op->HasAttr("skip_eager_deletion_vars")) {
        if (op->InputArgumentNames() == node.Op()->InputArgumentNames() &&
            op->OutputArgumentNames() == node.Op()->OutputArgumentNames()) {
          node.Op()->SetAttr("skip_eager_deletion_vars",
                             op->GetAttr("skip_eager_deletion_vars"));
        }
      }
    }
  }
}

// Converts the (already sorted) op nodes into OpDescs appended to `ops`.
// scale_loss_grad and (with NCCL/RCCL) allreduce nodes are replaced by
// equivalent program ops; op nodes without an OpDesc are dropped.
static void GetGraphOpDesc(const std::vector<Node *> &nodes,
                           proto::BlockDesc *block,
                           std::vector<OpDesc> *ops,
                           const Graph &graph,
                           const size_t graph_idx) {
  // True for adam/momentum/sgd ops that read a variable whose name contains
  // details::kFusedVarNamePrefix.
  auto is_fused_opt = [](Node *n) -> bool {
    auto op_type = n->Op()->Type();
    auto is_opt =
        (op_type == "adam" || op_type == "momentum" || op_type == "sgd");
    auto input_names = n->Op()->InputArgumentNames();
    auto contains_fused_var = std::any_of(
        input_names.begin(), input_names.end(), [](const std::string &name) {
          return name.find(details::kFusedVarNamePrefix) != std::string::npos;
        });
    VLOG(4) << is_opt << " " << contains_fused_var;
    return is_opt && contains_fused_var;
  };

  for (Node *n : nodes) {
    // if node is not Op, skip
    if (!n->IsOp()) continue;

    // create fill_constant op
    if (n->Name() == "scale_loss_grad") {
      VLOG(4) << "convert op node scale_loss_grad to desc fill_constant";
      ops->emplace_back();
      auto &desc = ops->back();
      ReplaceScaleLossGradOp(*n, &desc);
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
    } else if ((n->Name() == "allreduce" || n->Name() == "fused_all_reduce") &&
               dynamic_cast<details::NCCLOpHandleBase *>(
                   &(n->Wrapper<details::OpHandleBase>())) != nullptr) {
      VLOG(4) << "convert op node " << n->Name() << " to desc c_allreduce_sum";
      ReplaceAllReduceOp(*n, block, ops);
      VLOG(4) << n->ToString();
#endif
    } else if (n->Op()) {
      VLOG(4) << "convert op node to desc " << n->Op()->Type();
      if (is_fused_opt(n)) {
        OpDesc depend_desc(n->Op()->Block());

        std::vector<std::string> deps;
        for (auto in : n->inputs) {
          if (in->IsVar() && !in->IsCtrlVar()) {
            deps.push_back(in->Name());
          }
        }
        depend_desc.SetType("depend");
        depend_desc.SetInput("X",
                             n->Op()->Inputs().at(n->Op()->InputNames()[0]));
        depend_desc.SetInput("Dep", deps);
        depend_desc.SetOutput("Out",
                              n->Op()->Inputs().at(n->Op()->InputNames()[0]));
        ops->emplace_back(depend_desc);
        VLOG(4) << "add depend op";
      }
      if (n->Name() == "while" || n->Name() == "while_grad" ||
          n->Name() == "conditional_block" ||
          n->Name() == "conditional_block_grad" || n->Name() == "recurrent" ||
          n->Name() == "recurrent_grad") {
        VLOG(1) << "Update control op attr: skip_eager_deletion_vars";
        UpdateControlOpSkipEagerDeletionVars(*n, graph, graph_idx, n->Name());
      }
      ops->emplace_back(*n->Op());
      VLOG(4) << n->ToString();
    }
    // delete no OpDesc op
  }
}

// Appends the VarDesc proto of every variable node in `nodes` that belongs to
// the block of `graph`.
template <class T = Node *>
static void GetGraphVarDesc(const Graph &graph,
                            const std::unordered_set<T> &nodes,
                            std::vector<proto::VarDesc> *vars) {
  for (T node : nodes) {
    if (node->IsVar() && node->Var() &&
        node->GetVarNodeBlockId() == graph.GetBlockId()) {
      vars->emplace_back(*node->Var()->Proto());
    }
  }
}

// Rewrites the vars and ops of `block` from the contents of `graph`.
// `sort_kind` (may be null) selects the op ordering.
static void GraphToBlock(const Graph &graph,
                         proto::BlockDesc *block,
                         const SortKind *sort_kind,
                         const size_t graph_idx) {
  // Remove the unneeded variables after memory optimization.
  std::unordered_set<std::string> vars2remove;
  if (graph.Has(kGraphToProgramVarsToRemove)) {
    vars2remove = graph.Get<std::unordered_set<std::string>>(
        kGraphToProgramVarsToRemove);
    VLOG(2) << "graph (id: " << block->idx() << ") to program remove "
            << vars2remove.size() << " nodes";
  }

  std::vector<proto::VarDesc> vars_in_graph;
  GetGraphVarDesc<Node *>(graph, graph.Nodes(), &vars_in_graph);
  if (graph.Has(details::kRemovedVars)) {
    auto &removed_vars = graph.Get<details::RemovedVars>(details::kRemovedVars);
    GetGraphVarDesc<std::shared_ptr<ir::Node>>(
        graph, removed_vars, &vars_in_graph);
  }

  // add vars_in_graph to block
  block->clear_vars();
  std::unordered_set<std::string> visited_vars;
  for (proto::VarDesc &var : vars_in_graph) {
    const std::string &var_name = var.name();
    if (visited_vars.find(var_name) == visited_vars.end() &&
        vars2remove.find(var_name) == vars2remove.end()) {
      block->add_vars()->MergeFrom(var);
      visited_vars.insert(var_name);
    }
  }
  block->clear_ops();

  std::vector<Node *> nodes;
  if (sort_kind != nullptr) {
    // Inference Memory Optimize relies on this branch.
    nodes = TopologyVarientSort(graph, *sort_kind);
  } else {
    if (FLAGS_convert_all_blocks) {
      nodes = TopologySortGraphByDescOrder(graph);
    } else {
      nodes = TopologySortOperations(graph);
    }
  }

  std::vector<OpDesc> ops;
  GetGraphOpDesc(nodes, block, &ops, graph, graph_idx);

  for (auto &op : ops) {
    RemoveControlDepInputAndOuput(&op);
    block->add_ops()->MergeFrom(*op.Proto());
  }
}

// Converts the main graph back into `program`. With
// FLAGS_convert_all_blocks every sub graph is written to its own block,
// otherwise only the root block is rewritten.
void GraphToProgram(const Graph &graph,
                    ProgramDesc *program,
                    const SortKind *sort_kind) {
  PADDLE_ENFORCE_EQ(graph.IsMainGraph(),
                    true,
                    platform::errors::InvalidArgument(
                        "This graph is a sub_graph, "
                        "and can't convert to program individually"));
  PADDLE_ENFORCE_NOT_NULL(
      program,
      platform::errors::InvalidArgument(
          "program must not be nullptr when converting graph to program"));

  proto::ProgramDesc program_pb(*(program->Proto()));
  auto block = program_pb.mutable_blocks(kRootBlockIndex);
  block->set_idx(kRootBlockIndex);

  if (FLAGS_convert_all_blocks) {
    GraphToBlock(*graph.GetSubGraph(kRootBlockIndex),
                 block,
                 sort_kind,
                 graph.GetSubGraph(kRootBlockIndex)->GetBlockId());

    VLOG(3) << "Graph to program need convert " << graph.SubGraphsSize()
            << " sub graph";
    for (size_t idx = 0; idx < graph.SubGraphsSize(); ++idx) {
      // avoid kRootBlockIndex not 0
      if (idx == kRootBlockIndex) continue;

      if (static_cast<int>(idx) < program_pb.blocks_size()) {
        block = program_pb.mutable_blocks(idx);
      } else {
        block = program_pb.add_blocks();
        block->set_idx(idx);
        block->set_parent_idx(kRootBlockIndex);
      }
      GraphToBlock(*graph.GetSubGraph(idx),
                   block,
                   sort_kind,
                   graph.GetSubGraph(idx)->GetBlockId());
    }
  } else {
    GraphToBlock(graph, block, sort_kind, graph.GetBlockId());
  }

  program->CopyFrom(program_pb);

  if (graph.Has(details::kProgramDescs)) {
    details::ProgramDescs program_descs =
        graph.Get<details::ProgramDescs>(details::kProgramDescs);
    VLOG(8) << "Merge main programs";
    MergePrograms(program, program_descs, /*append=*/false);
  }
  // handle startup program
}

// Builds the op-to-op dependency matrix of one block: entry [i][j] tells
// whether op i runs after / before / is the same as / is independent of op j,
// where i and j are positions in block.AllOps().
static std::vector<std::vector<ir::Node::Dep>> GetOpDependencies(
    const BlockDesc &block, const std::unordered_set<ir::Node *> &nodes) {
  auto block_ops = block.AllOps();
  size_t op_num = block_ops.size();
  std::unordered_map<const ir::Node *, std::unordered_set<const ir::Node *>>
      preceding_ops(op_num);
  std::unordered_map<const ir::Node *, size_t> preceding_deps(op_num);
  std::unordered_map<const ir::Node *, std::unordered_set<const ir::Node *>>
      pending_ops(op_num);

  std::queue<const ir::Node *> ready_ops;
  for (const auto *node : nodes) {
    if (!node->IsOp()) continue;

    auto &tmp_preceding_ops = preceding_ops[node];
    for (const auto *in_var : node->inputs) {
      for (const auto *in_op : in_var->inputs) {
        tmp_preceding_ops.insert(in_op);
      }
    }
    if (tmp_preceding_ops.empty()) {
      ready_ops.push(node);
    }
    preceding_deps[node] = tmp_preceding_ops.size();

    auto &tmp_pending_ops = pending_ops[node];
    for (const auto *out_var : node->outputs) {
      for (const auto *out_op : out_var->outputs) {
        tmp_pending_ops.insert(out_op);
      }
    }
  }

  // Process ops in topological order, accumulating for each op the transitive
  // closure of the ops that precede it.
  std::unordered_map<const ir::Node *, std::unordered_set<const ir::Node *>>
      all_preceding_ops;
  while (!ready_ops.empty()) {
    const auto *cur_op = ready_ops.front();
    ready_ops.pop();

    auto &all_preceding_ops_of_cur_op = all_preceding_ops[cur_op];
    for (const auto *preceding_op : preceding_ops.at(cur_op)) {
      all_preceding_ops_of_cur_op.insert(preceding_op);
      auto &prev_preceding_ops = all_preceding_ops[preceding_op];
      all_preceding_ops_of_cur_op.insert(prev_preceding_ops.begin(),
                                         prev_preceding_ops.end());
    }

    for (const auto *pending_op : pending_ops.at(cur_op)) {
      if (--preceding_deps.at(pending_op) == 0) {
        ready_ops.push(pending_op);
      }
    }
  }

  // Map each OpDesc's original id to its position inside the block.
  std::unordered_map<uint64_t, size_t> op_id_to_idx(op_num);
  for (const auto *op_desc : block_ops) {
    size_t op_idx = op_id_to_idx.size();
    PADDLE_ENFORCE_EQ(
        op_id_to_idx.emplace(op_desc->OriginalId(), op_idx).second,
        true,
        platform::errors::InvalidArgument(
            "There should not be duplicate op id: %d", op_desc->OriginalId()));
  }

  std::vector<std::vector<ir::Node::Dep>> dep_matrix(op_num);
  for (size_t i = 0; i < op_num; ++i) {
    dep_matrix[i].resize(op_num, ir::Node::Dep::kNoDep);
    dep_matrix[i][i] = ir::Node::Dep::kSame;
  }

  auto get_op_idx_by_id = [&op_id_to_idx](uint64_t op_id) {
    auto iter = op_id_to_idx.find(op_id);
    PADDLE_ENFORCE_NE(iter,
                      op_id_to_idx.end(),
                      platform::errors::InvalidArgument(
                          "Cannot find OpDesc with id %d", op_id));
    return iter->second;
  };

  for (const auto &pair : all_preceding_ops) {
    const auto *cur_op_node = pair.first;
    size_t op_idx_1 = get_op_idx_by_id(cur_op_node->Op()->OriginalId());
    for (const auto *preceding_op_node : pair.second) {
      size_t op_idx_2 = get_op_idx_by_id(preceding_op_node->Op()->OriginalId());
      dep_matrix[op_idx_1][op_idx_2] = ir::Node::Dep::kAfter;
      dep_matrix[op_idx_2][op_idx_1] = ir::Node::Dep::kBefore;
    }
  }
  return dep_matrix;
}

// Builds one dependency matrix per block of `program`.
std::vector<std::vector<std::vector<ir::Node::Dep>>> GetOpDependencies(
    const ProgramDesc &program) {
  ir::Graph graph(program);
  size_t block_num = program.Size();
  std::vector<std::vector<std::vector<ir::Node::Dep>>> deps;
  deps.reserve(block_num);
  for (size_t i = 0; i < block_num; ++i) {
    deps.emplace_back(
        GetOpDependencies(program.Block(i), graph.GetSubGraph(i)->Nodes()));
  }
  return deps;
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle