diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index c42731e10f10828dae2245db5cfff3522d829924..0043229095b663077c9f628dcead16b6ab89ef9c 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -73,25 +73,30 @@ void TaskGraph::GeneratePersistenceThrdId( } } -void TaskGraph::AcyclicTopoForEachNode(std::function handler) const { +void TaskGraph::AcyclicTopoForEachNode(std::function IsAllowedStartNode, + std::function Handler) const { std::list starts; ForEachNode([&](TaskNode* node) { - if (node->in_edges().empty()) { starts.push_back(node); } + if (node->in_edges().empty() && IsAllowedStartNode(node)) { starts.push_back(node); } }); - auto ForEachInNode = [&](TaskNode* node, const std::function& handler) { + auto ForEachInNode = [&](TaskNode* node, const std::function& Handler) { node->ForEachNodeOnInEdge([&](TaskNode* node_on_in_edge) { if (IsBackEdge(node_on_in_edge, node)) return; - handler(const_cast(node_on_in_edge)); + Handler(const_cast(node_on_in_edge)); }); }; - auto ForEachOutNode = [&](TaskNode* node, const std::function& handler) { + auto ForEachOutNode = [&](TaskNode* node, const std::function& Handler) { node->ForEachNodeOnOutEdge([&](TaskNode* node_on_out_edge) { if (IsBackEdge(node, node_on_out_edge)) return; - handler(const_cast(node_on_out_edge)); + Handler(const_cast(node_on_out_edge)); }); }; // DfsTopo will cause inappropriate chain graph - TopoForEachNode(starts, ForEachInNode, ForEachOutNode, handler); + TopoForEachNode(starts, ForEachInNode, ForEachOutNode, Handler); +} + +void TaskGraph::AcyclicTopoForEachNode(std::function Handler) const { + return AcyclicTopoForEachNode([](TaskNode*) { return true; }, Handler); } void TaskGraph::RemoveEmptyRegsts() { diff --git a/oneflow/core/graph/task_graph.h b/oneflow/core/graph/task_graph.h index fb853317e79b578e84ddabbfebcec243096f6f67..1562aed82da876c58ad786502b2cc97ce08387ac 100644 --- a/oneflow/core/graph/task_graph.h +++ b/oneflow/core/graph/task_graph.h @@ -53,7 +53,9 @@ class TaskGraph final : public Graph { void AddMutexCtrlEdgeInSameChain(); void AddOrderCtrlEdgeBetweenCopyAndMdUpdt(); void RmUselessConsumeRelationshipBetweenFwBw(); - void AcyclicTopoForEachNode(std::function handler) const; + void AcyclicTopoForEachNode(std::function Handler) const; + void AcyclicTopoForEachNode(std::function IsAllowedStartNode, + std::function Handler) const; #define DECLARE_BLD_SUB_TASK_GRAPH_METHOD(method_name) void method_name BLD_SUB_TSK_GPH_MTHD_ARGS(); diff --git a/oneflow/core/job/compiler.cpp b/oneflow/core/job/compiler.cpp index eca7a1079772dd9d3559a97ca0e0b264e90f757d..868dabae9aea7137c34d74679cd7b06e0212363a 100644 --- a/oneflow/core/job/compiler.cpp +++ b/oneflow/core/job/compiler.cpp @@ -101,12 +101,10 @@ Plan Compiler::DoCompile() { task_gph->ForEachNode(std::bind(&TaskNode::ProduceAllRegstsAndBindEdges, _1)); task_gph->ForEachNode(std::bind(&TaskNode::ConsumeAllRegsts, _1)); task_gph->ForEachNode(std::bind(&TaskNode::PinConsumedRegst, _1)); - task_gph->AcyclicTopoForEachNode([](TaskNode* node) { - if (node->GetTaskType() != kNormalMdUpdt) { node->Build(); } - }); - task_gph->AcyclicTopoForEachNode([](TaskNode* node) { - if (node->GetTaskType() == kNormalMdUpdt) { node->Build(); } - }); + task_gph->AcyclicTopoForEachNode( + [](TaskNode* node) { return node->GetTaskType() != kNormalMdUpdt; }, &TaskNode::Build); + task_gph->AcyclicTopoForEachNode( + [](TaskNode* node) { return node->GetTaskType() == kNormalMdUpdt; }, &TaskNode::Build); task_gph->RemoveEmptyRegsts(); task_gph->AddOrderingCtrlEdgeInSameChain(); if (job_desc->IsTrain() && job_desc->enable_mem_sharing()) {