提交 52f3e34c 编写于 作者: S strickland12 提交者: Jinhui Yuan

refine ChainGraph (#1123)

* rm TryMerge

* rm extra loop in CollectAncestorsForEachNode


Former-commit-id: 8616e0d4
上级 faae9e38
......@@ -21,7 +21,6 @@ class ChainMerger final {
void InitTaskNode2UId();
void InitChains();
bool DoMerge(std::list<ChainIt>& chains, ChainIt rhs);
bool TryMerge();
void MergeChains();
int64_t get_task_uid(TaskNode* task_node) const {
auto uid_it = task_node2uid_.find(task_node);
......@@ -79,26 +78,19 @@ bool ChainMerger::DoMerge(std::list<ChainIt>& chains, ChainIt rhs) {
return false;
}
bool ChainMerger::TryMerge() {
void ChainMerger::MergeChains() {
HashMap<std::pair<int64_t, int64_t>, std::list<ChainIt>> stream_area2chains;
bool merge_happened = false;
for (auto cur_chain_it = chain_list_.begin(); cur_chain_it != chain_list_.end();) {
std::pair<int64_t, int64_t> stream_area_id = {cur_chain_it->stream_id, cur_chain_it->area_id};
auto stream_area_it = stream_area2chains.find(stream_area_id);
if (stream_area_it != stream_area2chains.end()
&& DoMerge(stream_area_it->second, cur_chain_it)) {
cur_chain_it = chain_list_.erase(cur_chain_it);
merge_happened = true;
} else {
stream_area2chains[stream_area_id].push_back(cur_chain_it);
++cur_chain_it;
}
}
return merge_happened;
}
void ChainMerger::MergeChains() {
while (TryMerge()) {}
}
} // namespace
......
......@@ -484,19 +484,14 @@ void TaskGraph::AddOrderCtrlEdgeBetweenCopyAndMdUpdt() {
}
void TaskGraph::CollectAncestorsForEachNode() {
std::vector<TaskNode*> ordered_nodes;
AcyclicTopoForEachNode([&](TaskNode* node) { ordered_nodes.emplace_back(node); });
for (auto it = ordered_nodes.begin(); it != ordered_nodes.end(); ++it) {
TaskNode* task_node = *it;
AcyclicTopoForEachNode([&](TaskNode* task_node) {
task_node->mut_ancestors().clear();
task_node->ForEachNodeOnInEdge([&](TaskNode* node_on_in_edge) {
if (IsBackEdge(node_on_in_edge, task_node)) return;
task_node->mut_ancestors().insert(node_on_in_edge->ancestors().begin(),
node_on_in_edge->ancestors().end());
task_node->mut_ancestors().insert(node_on_in_edge);
task_node->ForEachNodeOnInEdge([&](TaskNode* in_node) {
if (IsBackEdge(in_node, task_node)) return;
task_node->mut_ancestors().insert(in_node->ancestors().begin(), in_node->ancestors().end());
task_node->mut_ancestors().insert(in_node);
});
}
});
}
void TaskGraph::AcyclicTopoForEachNode(std::function<void(TaskNode* node)> handler) const {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册