diff --git a/oneflow/core/graph/chain_graph.cpp b/oneflow/core/graph/chain_graph.cpp index 9262ffbb63db879d083beced8287ced98e0c322b..1f387c06c6be7bab885af9708d1cecfa31c09179 100644 --- a/oneflow/core/graph/chain_graph.cpp +++ b/oneflow/core/graph/chain_graph.cpp @@ -21,7 +21,6 @@ class ChainMerger final { void InitTaskNode2UId(); void InitChains(); bool DoMerge(std::list& chains, ChainIt rhs); - bool TryMerge(); void MergeChains(); int64_t get_task_uid(TaskNode* task_node) const { auto uid_it = task_node2uid_.find(task_node); @@ -79,26 +78,19 @@ bool ChainMerger::DoMerge(std::list& chains, ChainIt rhs) { return false; } -bool ChainMerger::TryMerge() { +void ChainMerger::MergeChains() { HashMap, std::list> stream_area2chains; - bool merge_happened = false; for (auto cur_chain_it = chain_list_.begin(); cur_chain_it != chain_list_.end();) { std::pair stream_area_id = {cur_chain_it->stream_id, cur_chain_it->area_id}; auto stream_area_it = stream_area2chains.find(stream_area_id); if (stream_area_it != stream_area2chains.end() && DoMerge(stream_area_it->second, cur_chain_it)) { cur_chain_it = chain_list_.erase(cur_chain_it); - merge_happened = true; } else { stream_area2chains[stream_area_id].push_back(cur_chain_it); ++cur_chain_it; } } - return merge_happened; -} - -void ChainMerger::MergeChains() { - while (TryMerge()) {} } } // namespace diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index dc75e2c313b0ae7f5266c3adf358262d9645191a..2d48617303f5c993115c34276f47bbc85bf7391a 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -484,19 +484,14 @@ void TaskGraph::AddOrderCtrlEdgeBetweenCopyAndMdUpdt() { } void TaskGraph::CollectAncestorsForEachNode() { - std::vector ordered_nodes; - AcyclicTopoForEachNode([&](TaskNode* node) { ordered_nodes.emplace_back(node); }); - for (auto it = ordered_nodes.begin(); it != ordered_nodes.end(); ++it) { - TaskNode* task_node = *it; + AcyclicTopoForEachNode([&](TaskNode* task_node) { task_node->mut_ancestors().clear(); - task_node->ForEachNodeOnInEdge([&](TaskNode* node_on_in_edge) { - if (IsBackEdge(node_on_in_edge, task_node)) return; - task_node->mut_ancestors().insert(node_on_in_edge->ancestors().begin(), - node_on_in_edge->ancestors().end()); - task_node->mut_ancestors().insert(node_on_in_edge); - + task_node->ForEachNodeOnInEdge([&](TaskNode* in_node) { + if (IsBackEdge(in_node, task_node)) return; + task_node->mut_ancestors().insert(in_node->ancestors().begin(), in_node->ancestors().end()); + task_node->mut_ancestors().insert(in_node); }); - } + }); } void TaskGraph::AcyclicTopoForEachNode(std::function handler) const {