提交 6c7fb61c 编写于 作者: J Jinhui Yuan 提交者: GitHub

fix bugs in prediction mode (#1194)



Former-commit-id: 2ebe0205
上级 b84b880c
......@@ -76,7 +76,7 @@ void TaskGraph::GeneratePersistenceThrdId(
void TaskGraph::AcyclicTopoForEachNode(std::function<void(TaskNode* node)> handler) const {
std::list<TaskNode*> starts;
ForEachNode([&](TaskNode* node) {
if (node->consumed_regsts().empty() && !node->IsMeaningLess()) { starts.push_back(node); }
if (node->in_edges().empty()) { starts.push_back(node); }
});
auto ForEachInNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& handler) {
node->ForEachNodeOnInEdge([&](TaskNode* node_on_in_edge) {
......
......@@ -81,7 +81,7 @@ void TaskNode::PinConsumedRegst() {
}
void TaskNode::Build() {
CHECK(IsReadyForBuild());
if (consumed_regsts_.size()) { CHECK(IsReadyForBuild()); }
BuildExecGphAndRegst();
LockRegsts();
FixRegisterNumRange();
......@@ -98,7 +98,7 @@ void TaskNode::EraseZeroSizeConsumedRegst() {
auto regst_ptr = *it;
CHECK(regst_ptr);
if (regst_ptr->regst_desc_type().has_data_regst_desc() && regst_ptr->NumOfLbi() == 0) {
pair.second.erase(it++);
it = pair.second.erase(it);
} else {
++it;
}
......
......@@ -101,8 +101,16 @@ Plan Compiler::DoCompile() {
task_gph->ForEachNode(std::bind(&TaskNode::ProduceAllRegstsAndBindEdges, _1));
task_gph->ForEachNode(std::bind(&TaskNode::ConsumeAllRegsts, _1));
task_gph->ForEachNode(std::bind(&TaskNode::PinConsumedRegst, _1));
task_gph->AcyclicTopoForEachNode(
[](TaskNode* node) { node->Build(); }); // kMdUpdt task will not be built in Prediction mode
if (job_desc->IsTrain()) {
task_gph->AcyclicTopoForEachNode([](TaskNode* node) { node->Build(); });
} else {
task_gph->AcyclicTopoForEachNode([](TaskNode* node) {
if (node->GetTaskType() != kNormalMdUpdt) { node->Build(); }
});
task_gph->AcyclicTopoForEachNode([](TaskNode* node) {
if (node->GetTaskType() == kNormalMdUpdt) { node->Build(); }
});
}
task_gph->RemoveEmptyRegsts();
task_gph->AddOrderingCtrlEdgeInSameChain();
if (job_desc->IsTrain() && job_desc->enable_mem_sharing()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册