From cbf36fb9d14e64ec62d0ed149d448c79c3e1d5b2 Mon Sep 17 00:00:00 2001 From: Li Xinqi Date: Thu, 13 Sep 2018 14:08:09 +0800 Subject: [PATCH] mdupdt delayed topo (#1227) Former-commit-id: 317267a0a9b2129f42222d42b9e09ebf70daa1a7 --- oneflow/core/graph/task_graph.cpp | 26 ++++++++++++++++++++++---- oneflow/core/graph/task_graph.h | 5 +++-- oneflow/core/job/compiler.cpp | 6 ++---- 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index 4477643b4e..8d0b83c41d 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -73,12 +73,21 @@ void TaskGraph::GeneratePersistenceThrdId( } } +void TaskGraph::MdUpdtDelayedTopoForEachNode(std::function Handler) const { + HashSet built_nodes; + auto Build = [&](TaskNode* node) { + CHECK(built_nodes.emplace(node).second); + Handler(node); + }; + AcyclicTopoForEachNode([](TaskNode* node) { return node->GetTaskType() != kNormalMdUpdt; }, + Build); + AcyclicTopoForEachNode([](TaskNode* node) { return node->GetTaskType() == kNormalMdUpdt; }, + Build); + ForEachNode([&](TaskNode* node) { CHECK(built_nodes.find(node) != built_nodes.end()); }); +} + void TaskGraph::AcyclicTopoForEachNode(std::function IsAllowedStartNode, std::function Handler) const { - std::list starts; - ForEachNode([&](TaskNode* node) { - if (node->in_edges().empty() && IsAllowedStartNode(node)) { starts.push_back(node); } - }); auto ForEachInNode = [&](TaskNode* node, const std::function& Handler) { node->ForEachNodeOnInEdge([&](TaskNode* node_on_in_edge) { if (IsBackEdge(node_on_in_edge, node)) return; @@ -91,6 +100,15 @@ void TaskGraph::AcyclicTopoForEachNode(std::function IsAll Handler(const_cast(node_on_out_edge)); }); }; + auto IsSourceNode = [&](TaskNode* node) { + int32_t in_node_num = 0; + ForEachInNode(node, [&](TaskNode* in_node) { ++in_node_num; }); + return in_node_num == 0; + }; + std::list starts; + ForEachNode([&](TaskNode* node) { + if (IsSourceNode(node) && IsAllowedStartNode(node)) { starts.push_back(node); } + }); // DfsTopo will cause inappropriate chain graph TopoForEachNode(starts, ForEachInNode, ForEachOutNode, Handler); } diff --git a/oneflow/core/graph/task_graph.h b/oneflow/core/graph/task_graph.h index 42245d92a4..2313718b98 100644 --- a/oneflow/core/graph/task_graph.h +++ b/oneflow/core/graph/task_graph.h @@ -57,8 +57,7 @@ class TaskGraph final : public Graph { void AddOrderCtrlEdgeBetweenCopyAndMdUpdt(); void RmUselessConsumeRelationshipBetweenFwBw(); void AcyclicTopoForEachNode(std::function Handler) const; - void AcyclicTopoForEachNode(std::function IsAllowedStartNode, - std::function Handler) const; + void MdUpdtDelayedTopoForEachNode(std::function Handler) const; #define DECLARE_BLD_SUB_TASK_GRAPH_METHOD(method_name) void method_name BLD_SUB_TSK_GPH_MTHD_ARGS(); @@ -72,6 +71,8 @@ class TaskGraph final : public Graph { DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByReduceGather2ReduceGather); private: + void AcyclicTopoForEachNode(std::function IsAllowedStartNode, + std::function Handler) const; void BuildTaskPath( CompTaskNode* src, CompTaskNode* dst, std::function diff --git a/oneflow/core/job/compiler.cpp b/oneflow/core/job/compiler.cpp index 868dabae9a..dff291eefc 100644 --- a/oneflow/core/job/compiler.cpp +++ b/oneflow/core/job/compiler.cpp @@ -41,6 +41,7 @@ void ToDotFile(const Plan& plan, const std::string& filepath) { } out_stream << "}\n"; } + } // namespace Plan Compiler::Compile() { @@ -101,10 +102,7 @@ Plan Compiler::DoCompile() { task_gph->ForEachNode(std::bind(&TaskNode::ProduceAllRegstsAndBindEdges, _1)); task_gph->ForEachNode(std::bind(&TaskNode::ConsumeAllRegsts, _1)); task_gph->ForEachNode(std::bind(&TaskNode::PinConsumedRegst, _1)); - task_gph->AcyclicTopoForEachNode( - [](TaskNode* node) { return node->GetTaskType() != kNormalMdUpdt; }, &TaskNode::Build); - task_gph->AcyclicTopoForEachNode( - [](TaskNode* node) { return node->GetTaskType() == kNormalMdUpdt; }, &TaskNode::Build); + task_gph->MdUpdtDelayedTopoForEachNode(&TaskNode::Build); task_gph->RemoveEmptyRegsts(); task_gph->AddOrderingCtrlEdgeInSameChain(); if (job_desc->IsTrain() && job_desc->enable_mem_sharing()) { -- GitLab