From a5f1e5059de393b574b33507ff728d59f5790ea4 Mon Sep 17 00:00:00 2001 From: Li Xinqi Date: Mon, 3 Sep 2018 17:27:32 +0800 Subject: [PATCH] split sources when infer shape (#1202) Former-commit-id: 34fb73fed1086c7223d25aaf519a2182deca2ab0 --- oneflow/core/graph/task_graph.cpp | 19 ++++++++++++------- oneflow/core/graph/task_graph.h | 4 +++- oneflow/core/job/compiler.cpp | 10 ++++------ 3 files changed, 19 insertions(+), 14 deletions(-) diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index c42731e10f..0043229095 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -73,25 +73,30 @@ void TaskGraph::GeneratePersistenceThrdId( } } -void TaskGraph::AcyclicTopoForEachNode(std::function handler) const { +void TaskGraph::AcyclicTopoForEachNode(std::function IsAllowedStartNode, + std::function Handler) const { std::list starts; ForEachNode([&](TaskNode* node) { - if (node->in_edges().empty()) { starts.push_back(node); } + if (node->in_edges().empty() && IsAllowedStartNode(node)) { starts.push_back(node); } }); - auto ForEachInNode = [&](TaskNode* node, const std::function& handler) { + auto ForEachInNode = [&](TaskNode* node, const std::function& Handler) { node->ForEachNodeOnInEdge([&](TaskNode* node_on_in_edge) { if (IsBackEdge(node_on_in_edge, node)) return; - handler(const_cast(node_on_in_edge)); + Handler(const_cast(node_on_in_edge)); }); }; - auto ForEachOutNode = [&](TaskNode* node, const std::function& handler) { + auto ForEachOutNode = [&](TaskNode* node, const std::function& Handler) { node->ForEachNodeOnOutEdge([&](TaskNode* node_on_out_edge) { if (IsBackEdge(node, node_on_out_edge)) return; - handler(const_cast(node_on_out_edge)); + Handler(const_cast(node_on_out_edge)); }); }; // DfsTopo will cause inappropriate chain graph - TopoForEachNode(starts, ForEachInNode, ForEachOutNode, handler); + TopoForEachNode(starts, ForEachInNode, ForEachOutNode, Handler); +} + +void TaskGraph::AcyclicTopoForEachNode(std::function Handler) const { + return AcyclicTopoForEachNode([](TaskNode*) { return true; }, Handler); } void TaskGraph::RemoveEmptyRegsts() { diff --git a/oneflow/core/graph/task_graph.h b/oneflow/core/graph/task_graph.h index fb853317e7..1562aed82d 100644 --- a/oneflow/core/graph/task_graph.h +++ b/oneflow/core/graph/task_graph.h @@ -53,7 +53,9 @@ class TaskGraph final : public Graph { void AddMutexCtrlEdgeInSameChain(); void AddOrderCtrlEdgeBetweenCopyAndMdUpdt(); void RmUselessConsumeRelationshipBetweenFwBw(); - void AcyclicTopoForEachNode(std::function handler) const; + void AcyclicTopoForEachNode(std::function Handler) const; + void AcyclicTopoForEachNode(std::function IsAllowedStartNode, + std::function Handler) const; #define DECLARE_BLD_SUB_TASK_GRAPH_METHOD(method_name) void method_name BLD_SUB_TSK_GPH_MTHD_ARGS(); diff --git a/oneflow/core/job/compiler.cpp b/oneflow/core/job/compiler.cpp index eca7a10797..868dabae9a 100644 --- a/oneflow/core/job/compiler.cpp +++ b/oneflow/core/job/compiler.cpp @@ -101,12 +101,10 @@ Plan Compiler::DoCompile() { task_gph->ForEachNode(std::bind(&TaskNode::ProduceAllRegstsAndBindEdges, _1)); task_gph->ForEachNode(std::bind(&TaskNode::ConsumeAllRegsts, _1)); task_gph->ForEachNode(std::bind(&TaskNode::PinConsumedRegst, _1)); - task_gph->AcyclicTopoForEachNode([](TaskNode* node) { - if (node->GetTaskType() != kNormalMdUpdt) { node->Build(); } - }); - task_gph->AcyclicTopoForEachNode([](TaskNode* node) { - if (node->GetTaskType() == kNormalMdUpdt) { node->Build(); } - }); + task_gph->AcyclicTopoForEachNode( + [](TaskNode* node) { return node->GetTaskType() != kNormalMdUpdt; }, &TaskNode::Build); + task_gph->AcyclicTopoForEachNode( + [](TaskNode* node) { return node->GetTaskType() == kNormalMdUpdt; }, &TaskNode::Build); task_gph->RemoveEmptyRegsts(); task_gph->AddOrderingCtrlEdgeInSameChain(); if (job_desc->IsTrain() && job_desc->enable_mem_sharing()) { -- GitLab