diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index 4477643b4e205b849f01994fa934e4caf0d64157..8d0b83c41d81c64481b523ea62ad2df1fd5785eb 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -73,12 +73,21 @@ void TaskGraph::GeneratePersistenceThrdId( } } +void TaskGraph::MdUpdtDelayedTopoForEachNode(std::function Handler) const { + HashSet built_nodes; + auto Build = [&](TaskNode* node) { + CHECK(built_nodes.emplace(node).second); + Handler(node); + }; + AcyclicTopoForEachNode([](TaskNode* node) { return node->GetTaskType() != kNormalMdUpdt; }, + Build); + AcyclicTopoForEachNode([](TaskNode* node) { return node->GetTaskType() == kNormalMdUpdt; }, + Build); + ForEachNode([&](TaskNode* node) { CHECK(built_nodes.find(node) != built_nodes.end()); }); +} + void TaskGraph::AcyclicTopoForEachNode(std::function IsAllowedStartNode, std::function Handler) const { - std::list starts; - ForEachNode([&](TaskNode* node) { - if (node->in_edges().empty() && IsAllowedStartNode(node)) { starts.push_back(node); } - }); auto ForEachInNode = [&](TaskNode* node, const std::function& Handler) { node->ForEachNodeOnInEdge([&](TaskNode* node_on_in_edge) { if (IsBackEdge(node_on_in_edge, node)) return; @@ -91,6 +100,15 @@ void TaskGraph::AcyclicTopoForEachNode(std::function IsAll Handler(const_cast(node_on_out_edge)); }); }; + auto IsSourceNode = [&](TaskNode* node) { + int32_t in_node_num = 0; + ForEachInNode(node, [&](TaskNode* in_node) { ++in_node_num; }); + return in_node_num == 0; + }; + std::list starts; + ForEachNode([&](TaskNode* node) { + if (IsSourceNode(node) && IsAllowedStartNode(node)) { starts.push_back(node); } + }); // DfsTopo will cause inappropriate chain graph TopoForEachNode(starts, ForEachInNode, ForEachOutNode, Handler); } diff --git a/oneflow/core/graph/task_graph.h b/oneflow/core/graph/task_graph.h index 42245d92a41e86ffd26ec2ef039bae2a0da2ea26..2313718b98c0cc951bb32386c46ea3cdcc33a61d 100644 --- a/oneflow/core/graph/task_graph.h +++ b/oneflow/core/graph/task_graph.h @@ -57,8 +57,7 @@ class TaskGraph final : public Graph { void AddOrderCtrlEdgeBetweenCopyAndMdUpdt(); void RmUselessConsumeRelationshipBetweenFwBw(); void AcyclicTopoForEachNode(std::function Handler) const; - void AcyclicTopoForEachNode(std::function IsAllowedStartNode, - std::function Handler) const; + void MdUpdtDelayedTopoForEachNode(std::function Handler) const; #define DECLARE_BLD_SUB_TASK_GRAPH_METHOD(method_name) void method_name BLD_SUB_TSK_GPH_MTHD_ARGS(); @@ -72,6 +71,8 @@ class TaskGraph final : public Graph { DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByReduceGather2ReduceGather); private: + void AcyclicTopoForEachNode(std::function IsAllowedStartNode, + std::function Handler) const; void BuildTaskPath( CompTaskNode* src, CompTaskNode* dst, std::function diff --git a/oneflow/core/job/compiler.cpp b/oneflow/core/job/compiler.cpp index 868dabae9aea7137c34d74679cd7b06e0212363a..dff291eefc72a5c5577bbab204b544e950f9a7be 100644 --- a/oneflow/core/job/compiler.cpp +++ b/oneflow/core/job/compiler.cpp @@ -41,6 +41,7 @@ void ToDotFile(const Plan& plan, const std::string& filepath) { } out_stream << "}\n"; } + } // namespace Plan Compiler::Compile() { @@ -101,10 +102,7 @@ Plan Compiler::DoCompile() { task_gph->ForEachNode(std::bind(&TaskNode::ProduceAllRegstsAndBindEdges, _1)); task_gph->ForEachNode(std::bind(&TaskNode::ConsumeAllRegsts, _1)); task_gph->ForEachNode(std::bind(&TaskNode::PinConsumedRegst, _1)); - task_gph->AcyclicTopoForEachNode( - [](TaskNode* node) { return node->GetTaskType() != kNormalMdUpdt; }, &TaskNode::Build); - task_gph->AcyclicTopoForEachNode( - [](TaskNode* node) { return node->GetTaskType() == kNormalMdUpdt; }, &TaskNode::Build); + task_gph->MdUpdtDelayedTopoForEachNode(&TaskNode::Build); task_gph->RemoveEmptyRegsts(); task_gph->AddOrderingCtrlEdgeInSameChain(); if (job_desc->IsTrain() && job_desc->enable_mem_sharing()) {