From 6c7fb61c2821cb692e3f3e03dfb5b3811f9ce6a8 Mon Sep 17 00:00:00 2001 From: Jinhui Yuan Date: Sun, 2 Sep 2018 15:40:54 +0800 Subject: [PATCH] fix bugs in prediction mode (#1194) Former-commit-id: 2ebe0205b50239d141656150b1bcde3910c966a3 --- oneflow/core/graph/task_graph.cpp | 2 +- oneflow/core/graph/task_node.cpp | 4 ++-- oneflow/core/job/compiler.cpp | 12 ++++++++++-- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index 9b5d460d16..c42731e10f 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -76,7 +76,7 @@ void TaskGraph::GeneratePersistenceThrdId( void TaskGraph::AcyclicTopoForEachNode(std::function handler) const { std::list starts; ForEachNode([&](TaskNode* node) { - if (node->consumed_regsts().empty() && !node->IsMeaningLess()) { starts.push_back(node); } + if (node->in_edges().empty()) { starts.push_back(node); } }); auto ForEachInNode = [&](TaskNode* node, const std::function& handler) { node->ForEachNodeOnInEdge([&](TaskNode* node_on_in_edge) { diff --git a/oneflow/core/graph/task_node.cpp b/oneflow/core/graph/task_node.cpp index 82498ec807..613c796cfe 100644 --- a/oneflow/core/graph/task_node.cpp +++ b/oneflow/core/graph/task_node.cpp @@ -81,7 +81,7 @@ void TaskNode::PinConsumedRegst() { } void TaskNode::Build() { - CHECK(IsReadyForBuild()); + if (consumed_regsts_.size()) { CHECK(IsReadyForBuild()); } BuildExecGphAndRegst(); LockRegsts(); FixRegisterNumRange(); @@ -98,7 +98,7 @@ void TaskNode::EraseZeroSizeConsumedRegst() { auto regst_ptr = *it; CHECK(regst_ptr); if (regst_ptr->regst_desc_type().has_data_regst_desc() && regst_ptr->NumOfLbi() == 0) { - pair.second.erase(it++); + it = pair.second.erase(it); } else { ++it; } diff --git a/oneflow/core/job/compiler.cpp b/oneflow/core/job/compiler.cpp index 6d775d12a9..f4a47248b1 100644 --- a/oneflow/core/job/compiler.cpp +++ b/oneflow/core/job/compiler.cpp @@ -101,8 +101,16 @@ Plan Compiler::DoCompile() { task_gph->ForEachNode(std::bind(&TaskNode::ProduceAllRegstsAndBindEdges, _1)); task_gph->ForEachNode(std::bind(&TaskNode::ConsumeAllRegsts, _1)); task_gph->ForEachNode(std::bind(&TaskNode::PinConsumedRegst, _1)); - task_gph->AcyclicTopoForEachNode( - [](TaskNode* node) { node->Build(); }); // kMdUpdt task will not be built in Prediction mode + if (job_desc->IsTrain()) { + task_gph->AcyclicTopoForEachNode([](TaskNode* node) { node->Build(); }); + } else { + task_gph->AcyclicTopoForEachNode([](TaskNode* node) { + if (node->GetTaskType() != kNormalMdUpdt) { node->Build(); } + }); + task_gph->AcyclicTopoForEachNode([](TaskNode* node) { + if (node->GetTaskType() == kNormalMdUpdt) { node->Build(); } + }); + } task_gph->RemoveEmptyRegsts(); task_gph->AddOrderingCtrlEdgeInSameChain(); if (job_desc->IsTrain() && job_desc->enable_mem_sharing()) { -- GitLab