diff --git a/oneflow/core/job/compiler.cpp b/oneflow/core/job/compiler.cpp index f4a47248b15c7ff3624e9378594f5972abe0d734..eca7a1079772dd9d3559a97ca0e0b264e90f757d 100644 --- a/oneflow/core/job/compiler.cpp +++ b/oneflow/core/job/compiler.cpp @@ -101,16 +101,12 @@ Plan Compiler::DoCompile() { task_gph->ForEachNode(std::bind(&TaskNode::ProduceAllRegstsAndBindEdges, _1)); task_gph->ForEachNode(std::bind(&TaskNode::ConsumeAllRegsts, _1)); task_gph->ForEachNode(std::bind(&TaskNode::PinConsumedRegst, _1)); - if (job_desc->IsTrain()) { - task_gph->AcyclicTopoForEachNode([](TaskNode* node) { node->Build(); }); - } else { - task_gph->AcyclicTopoForEachNode([](TaskNode* node) { - if (node->GetTaskType() != kNormalMdUpdt) { node->Build(); } - }); - task_gph->AcyclicTopoForEachNode([](TaskNode* node) { - if (node->GetTaskType() == kNormalMdUpdt) { node->Build(); } - }); - } + task_gph->AcyclicTopoForEachNode([](TaskNode* node) { + if (node->GetTaskType() != kNormalMdUpdt) { node->Build(); } + }); + task_gph->AcyclicTopoForEachNode([](TaskNode* node) { + if (node->GetTaskType() == kNormalMdUpdt) { node->Build(); } + }); task_gph->RemoveEmptyRegsts(); task_gph->AddOrderingCtrlEdgeInSameChain(); if (job_desc->IsTrain() && job_desc->enable_mem_sharing()) {