提交 a5f1e505 编写于 作者: L Li Xinqi 提交者: Jinhui Yuan

split sources when infer shape (#1202)



Former-commit-id: 34fb73fe
上级 dd9be365
...@@ -73,25 +73,30 @@ void TaskGraph::GeneratePersistenceThrdId( ...@@ -73,25 +73,30 @@ void TaskGraph::GeneratePersistenceThrdId(
} }
} }
void TaskGraph::AcyclicTopoForEachNode(std::function<void(TaskNode* node)> handler) const { void TaskGraph::AcyclicTopoForEachNode(std::function<bool(TaskNode* node)> IsAllowedStartNode,
std::function<void(TaskNode* node)> Handler) const {
std::list<TaskNode*> starts; std::list<TaskNode*> starts;
ForEachNode([&](TaskNode* node) { ForEachNode([&](TaskNode* node) {
if (node->in_edges().empty()) { starts.push_back(node); } if (node->in_edges().empty() && IsAllowedStartNode(node)) { starts.push_back(node); }
}); });
auto ForEachInNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& handler) { auto ForEachInNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& Handler) {
node->ForEachNodeOnInEdge([&](TaskNode* node_on_in_edge) { node->ForEachNodeOnInEdge([&](TaskNode* node_on_in_edge) {
if (IsBackEdge(node_on_in_edge, node)) return; if (IsBackEdge(node_on_in_edge, node)) return;
handler(const_cast<TaskNode*>(node_on_in_edge)); Handler(const_cast<TaskNode*>(node_on_in_edge));
}); });
}; };
auto ForEachOutNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& handler) { auto ForEachOutNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& Handler) {
node->ForEachNodeOnOutEdge([&](TaskNode* node_on_out_edge) { node->ForEachNodeOnOutEdge([&](TaskNode* node_on_out_edge) {
if (IsBackEdge(node, node_on_out_edge)) return; if (IsBackEdge(node, node_on_out_edge)) return;
handler(const_cast<TaskNode*>(node_on_out_edge)); Handler(const_cast<TaskNode*>(node_on_out_edge));
}); });
}; };
// DfsTopo will cause inappropriate chain graph // DfsTopo will cause inappropriate chain graph
TopoForEachNode(starts, ForEachInNode, ForEachOutNode, handler); TopoForEachNode(starts, ForEachInNode, ForEachOutNode, Handler);
}
void TaskGraph::AcyclicTopoForEachNode(std::function<void(TaskNode* node)> Handler) const {
return AcyclicTopoForEachNode([](TaskNode*) { return true; }, Handler);
} }
void TaskGraph::RemoveEmptyRegsts() { void TaskGraph::RemoveEmptyRegsts() {
......
...@@ -53,7 +53,9 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> { ...@@ -53,7 +53,9 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
void AddMutexCtrlEdgeInSameChain(); void AddMutexCtrlEdgeInSameChain();
void AddOrderCtrlEdgeBetweenCopyAndMdUpdt(); void AddOrderCtrlEdgeBetweenCopyAndMdUpdt();
void RmUselessConsumeRelationshipBetweenFwBw(); void RmUselessConsumeRelationshipBetweenFwBw();
void AcyclicTopoForEachNode(std::function<void(TaskNode* node)> handler) const; void AcyclicTopoForEachNode(std::function<void(TaskNode* node)> Handler) const;
void AcyclicTopoForEachNode(std::function<bool(TaskNode* node)> IsAllowedStartNode,
std::function<void(TaskNode* node)> Handler) const;
#define DECLARE_BLD_SUB_TASK_GRAPH_METHOD(method_name) void method_name BLD_SUB_TSK_GPH_MTHD_ARGS(); #define DECLARE_BLD_SUB_TASK_GRAPH_METHOD(method_name) void method_name BLD_SUB_TSK_GPH_MTHD_ARGS();
......
...@@ -101,12 +101,10 @@ Plan Compiler::DoCompile() { ...@@ -101,12 +101,10 @@ Plan Compiler::DoCompile() {
task_gph->ForEachNode(std::bind(&TaskNode::ProduceAllRegstsAndBindEdges, _1)); task_gph->ForEachNode(std::bind(&TaskNode::ProduceAllRegstsAndBindEdges, _1));
task_gph->ForEachNode(std::bind(&TaskNode::ConsumeAllRegsts, _1)); task_gph->ForEachNode(std::bind(&TaskNode::ConsumeAllRegsts, _1));
task_gph->ForEachNode(std::bind(&TaskNode::PinConsumedRegst, _1)); task_gph->ForEachNode(std::bind(&TaskNode::PinConsumedRegst, _1));
task_gph->AcyclicTopoForEachNode([](TaskNode* node) { task_gph->AcyclicTopoForEachNode(
if (node->GetTaskType() != kNormalMdUpdt) { node->Build(); } [](TaskNode* node) { return node->GetTaskType() != kNormalMdUpdt; }, &TaskNode::Build);
}); task_gph->AcyclicTopoForEachNode(
task_gph->AcyclicTopoForEachNode([](TaskNode* node) { [](TaskNode* node) { return node->GetTaskType() == kNormalMdUpdt; }, &TaskNode::Build);
if (node->GetTaskType() == kNormalMdUpdt) { node->Build(); }
});
task_gph->RemoveEmptyRegsts(); task_gph->RemoveEmptyRegsts();
task_gph->AddOrderingCtrlEdgeInSameChain(); task_gph->AddOrderingCtrlEdgeInSameChain();
if (job_desc->IsTrain() && job_desc->enable_mem_sharing()) { if (job_desc->IsTrain() && job_desc->enable_mem_sharing()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册