提交 a5f1e505 编写于 作者: L Li Xinqi 提交者: Jinhui Yuan

split sources when infer shape (#1202)



Former-commit-id: 34fb73fe
上级 dd9be365
......@@ -73,25 +73,30 @@ void TaskGraph::GeneratePersistenceThrdId(
}
}
void TaskGraph::AcyclicTopoForEachNode(std::function<void(TaskNode* node)> handler) const {
void TaskGraph::AcyclicTopoForEachNode(std::function<bool(TaskNode* node)> IsAllowedStartNode,
std::function<void(TaskNode* node)> Handler) const {
std::list<TaskNode*> starts;
ForEachNode([&](TaskNode* node) {
if (node->in_edges().empty()) { starts.push_back(node); }
if (node->in_edges().empty() && IsAllowedStartNode(node)) { starts.push_back(node); }
});
auto ForEachInNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& handler) {
auto ForEachInNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& Handler) {
node->ForEachNodeOnInEdge([&](TaskNode* node_on_in_edge) {
if (IsBackEdge(node_on_in_edge, node)) return;
handler(const_cast<TaskNode*>(node_on_in_edge));
Handler(const_cast<TaskNode*>(node_on_in_edge));
});
};
auto ForEachOutNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& handler) {
auto ForEachOutNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& Handler) {
node->ForEachNodeOnOutEdge([&](TaskNode* node_on_out_edge) {
if (IsBackEdge(node, node_on_out_edge)) return;
handler(const_cast<TaskNode*>(node_on_out_edge));
Handler(const_cast<TaskNode*>(node_on_out_edge));
});
};
// DfsTopo will cause inappropriate chain graph
TopoForEachNode(starts, ForEachInNode, ForEachOutNode, handler);
TopoForEachNode(starts, ForEachInNode, ForEachOutNode, Handler);
}
void TaskGraph::AcyclicTopoForEachNode(std::function<void(TaskNode* node)> Handler) const {
return AcyclicTopoForEachNode([](TaskNode*) { return true; }, Handler);
}
void TaskGraph::RemoveEmptyRegsts() {
......
......@@ -53,7 +53,9 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
void AddMutexCtrlEdgeInSameChain();
void AddOrderCtrlEdgeBetweenCopyAndMdUpdt();
void RmUselessConsumeRelationshipBetweenFwBw();
void AcyclicTopoForEachNode(std::function<void(TaskNode* node)> handler) const;
void AcyclicTopoForEachNode(std::function<void(TaskNode* node)> Handler) const;
void AcyclicTopoForEachNode(std::function<bool(TaskNode* node)> IsAllowedStartNode,
std::function<void(TaskNode* node)> Handler) const;
#define DECLARE_BLD_SUB_TASK_GRAPH_METHOD(method_name) void method_name BLD_SUB_TSK_GPH_MTHD_ARGS();
......
......@@ -101,12 +101,10 @@ Plan Compiler::DoCompile() {
task_gph->ForEachNode(std::bind(&TaskNode::ProduceAllRegstsAndBindEdges, _1));
task_gph->ForEachNode(std::bind(&TaskNode::ConsumeAllRegsts, _1));
task_gph->ForEachNode(std::bind(&TaskNode::PinConsumedRegst, _1));
task_gph->AcyclicTopoForEachNode([](TaskNode* node) {
if (node->GetTaskType() != kNormalMdUpdt) { node->Build(); }
});
task_gph->AcyclicTopoForEachNode([](TaskNode* node) {
if (node->GetTaskType() == kNormalMdUpdt) { node->Build(); }
});
task_gph->AcyclicTopoForEachNode(
[](TaskNode* node) { return node->GetTaskType() != kNormalMdUpdt; }, &TaskNode::Build);
task_gph->AcyclicTopoForEachNode(
[](TaskNode* node) { return node->GetTaskType() == kNormalMdUpdt; }, &TaskNode::Build);
task_gph->RemoveEmptyRegsts();
task_gph->AddOrderingCtrlEdgeInSameChain();
if (job_desc->IsTrain() && job_desc->enable_mem_sharing()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册