提交 cbf36fb9 编写于 作者: L Li Xinqi 提交者: Jinhui Yuan

mdupdt delayed topo (#1227)



Former-commit-id: 317267a0
上级 f6de1c9a
......@@ -73,12 +73,21 @@ void TaskGraph::GeneratePersistenceThrdId(
}
}
void TaskGraph::MdUpdtDelayedTopoForEachNode(std::function<void(TaskNode* node)> Handler) const {
HashSet<const TaskNode*> built_nodes;
auto Build = [&](TaskNode* node) {
CHECK(built_nodes.emplace(node).second);
Handler(node);
};
AcyclicTopoForEachNode([](TaskNode* node) { return node->GetTaskType() != kNormalMdUpdt; },
Build);
AcyclicTopoForEachNode([](TaskNode* node) { return node->GetTaskType() == kNormalMdUpdt; },
Build);
ForEachNode([&](TaskNode* node) { CHECK(built_nodes.find(node) != built_nodes.end()); });
}
void TaskGraph::AcyclicTopoForEachNode(std::function<bool(TaskNode* node)> IsAllowedStartNode,
std::function<void(TaskNode* node)> Handler) const {
std::list<TaskNode*> starts;
ForEachNode([&](TaskNode* node) {
if (node->in_edges().empty() && IsAllowedStartNode(node)) { starts.push_back(node); }
});
auto ForEachInNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& Handler) {
node->ForEachNodeOnInEdge([&](TaskNode* node_on_in_edge) {
if (IsBackEdge(node_on_in_edge, node)) return;
......@@ -91,6 +100,15 @@ void TaskGraph::AcyclicTopoForEachNode(std::function<bool(TaskNode* node)> IsAll
Handler(const_cast<TaskNode*>(node_on_out_edge));
});
};
auto IsSourceNode = [&](TaskNode* node) {
int32_t in_node_num = 0;
ForEachInNode(node, [&](TaskNode* in_node) { ++in_node_num; });
return in_node_num == 0;
};
std::list<TaskNode*> starts;
ForEachNode([&](TaskNode* node) {
if (IsSourceNode(node) && IsAllowedStartNode(node)) { starts.push_back(node); }
});
// DfsTopo will cause inappropriate chain graph
TopoForEachNode(starts, ForEachInNode, ForEachOutNode, Handler);
}
......
......@@ -57,8 +57,7 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
void AddOrderCtrlEdgeBetweenCopyAndMdUpdt();
void RmUselessConsumeRelationshipBetweenFwBw();
void AcyclicTopoForEachNode(std::function<void(TaskNode* node)> Handler) const;
void AcyclicTopoForEachNode(std::function<bool(TaskNode* node)> IsAllowedStartNode,
std::function<void(TaskNode* node)> Handler) const;
void MdUpdtDelayedTopoForEachNode(std::function<void(TaskNode* node)> Handler) const;
#define DECLARE_BLD_SUB_TASK_GRAPH_METHOD(method_name) void method_name BLD_SUB_TSK_GPH_MTHD_ARGS();
......@@ -72,6 +71,8 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByReduceGather2ReduceGather);
private:
void AcyclicTopoForEachNode(std::function<bool(TaskNode* node)> IsAllowedStartNode,
std::function<void(TaskNode* node)> Handler) const;
void BuildTaskPath(
CompTaskNode* src, CompTaskNode* dst,
std::function<TaskNode**(CompTaskNode* src, int64_t machine_id, int32_t mem_zone_id)>
......
......@@ -41,6 +41,7 @@ void ToDotFile(const Plan& plan, const std::string& filepath) {
}
out_stream << "}\n";
}
} // namespace
Plan Compiler::Compile() {
......@@ -101,10 +102,7 @@ Plan Compiler::DoCompile() {
task_gph->ForEachNode(std::bind(&TaskNode::ProduceAllRegstsAndBindEdges, _1));
task_gph->ForEachNode(std::bind(&TaskNode::ConsumeAllRegsts, _1));
task_gph->ForEachNode(std::bind(&TaskNode::PinConsumedRegst, _1));
task_gph->AcyclicTopoForEachNode(
[](TaskNode* node) { return node->GetTaskType() != kNormalMdUpdt; }, &TaskNode::Build);
task_gph->AcyclicTopoForEachNode(
[](TaskNode* node) { return node->GetTaskType() == kNormalMdUpdt; }, &TaskNode::Build);
task_gph->MdUpdtDelayedTopoForEachNode(&TaskNode::Build);
task_gph->RemoveEmptyRegsts();
task_gph->AddOrderingCtrlEdgeInSameChain();
if (job_desc->IsTrain() && job_desc->enable_mem_sharing()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册