未验证 提交 a8b7dedb 编写于 作者: J Jinhui Yuan 提交者: GitHub

Build task node in topological order (#1162)

上级 e585bba0
......@@ -17,8 +17,6 @@ class Graph {
// For Each
void ForEachNode(std::function<void(NodeType*)> NodeHandler) const;
void ForEachNode(std::function<void(NodeType*)> NodeHandler,
std::function<bool(NodeType*)> IsNodeReady) const;
void TopoForEachNode(std::function<void(NodeType*)> NodeHandler) const;
void ReverseTopoForEachNode(std::function<void(NodeType*)> NodeHandler) const;
void ForEachEdge(std::function<void(EdgeType*)> EdgeHandler) const;
......@@ -80,30 +78,6 @@ void Graph<NodeType, EdgeType>::ForEachNode(std::function<void(NodeType*)> NodeH
for (auto& x : nodes_) { NodeHandler(x.get()); }
}
template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::ForEachNode(std::function<void(NodeType*)> NodeHandler,
std::function<bool(NodeType*)> IsNodeReady) const {
std::queue<NodeType*> node_queue;
HashSet<NodeType*> nodes_pushed;
for (auto& x : nodes_) {
if (IsNodeReady(x.get())) {
node_queue.push(x.get());
CHECK(nodes_pushed.insert(x.get()).second);
}
}
while (node_queue.empty() == false) {
NodeType* cur_node = node_queue.front();
node_queue.pop();
NodeHandler(cur_node);
cur_node->ForEachNodeOnInOutEdge([&](NodeType* candidate) {
if (nodes_pushed.find(candidate) == nodes_pushed.end() && IsNodeReady(candidate)) {
node_queue.push(candidate);
CHECK(nodes_pushed.insert(candidate).second);
}
});
}
}
template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::TopoForEachNode(std::function<void(NodeType*)> NodeHandler) const {
std::list<NodeType*> starts;
......
......@@ -81,6 +81,7 @@ void TaskNode::PinConsumedRegst() {
}
void TaskNode::Build() {
CHECK(IsReadyForBuild());
BuildExecGphAndRegst();
LockRegsts();
FixRegisterNumRange();
......
......@@ -60,7 +60,8 @@ Plan Compiler::DoCompile() {
task_gph->ForEachNode(std::bind(&TaskNode::ProduceAllRegstsAndBindEdges, _1));
task_gph->ForEachNode(std::bind(&TaskNode::ConsumeAllRegsts, _1));
task_gph->ForEachNode(std::bind(&TaskNode::PinConsumedRegst, _1));
task_gph->ForEachNode(std::bind(&TaskNode::Build, _1), std::bind(&TaskNode::IsReadyForBuild, _1));
task_gph->AcyclicTopoForEachNode(
[](TaskNode* node) { node->Build(); }); // kMdUpdt task will not be built in Prediction mode
task_gph->RemoveEmptyRegsts();
task_gph->AddOrderingCtrlEdgeInSameChain();
if (job_desc->IsTrain() && job_desc->enable_mem_sharing()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册