提交 c8035696 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!181 Do cse graph by graph

Merge pull request !181 from lyfne/cse_fix
......@@ -40,14 +40,14 @@ BasePtr AbsOf(const AnfNodePtr &node) {
return node_abs;
}
namespace {
void BuildOrderGroup(const FuncGraphManagerPtr manager, std::vector<std::size_t> *const order_group,
std::unordered_map<std::size_t, std::vector<AnfNodePtr>> *groups) {
MS_EXCEPTION_IF_NULL(order_group);
std::unordered_map<AnfNodePtr, std::size_t> hashes;
bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const {
bool changed = false;
for (FuncGraphPtr fg : manager->func_graphs()) {
MS_EXCEPTION_IF_NULL(fg);
std::vector<std::size_t> order_group;
std::unordered_map<std::size_t, std::vector<AnfNodePtr>> groups;
std::unordered_map<AnfNodePtr, std::size_t> hashes;
std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
for (auto node : toposet) {
MS_EXCEPTION_IF_NULL(node);
......@@ -75,17 +75,20 @@ void BuildOrderGroup(const FuncGraphManagerPtr manager, std::vector<std::size_t>
}
hashes[node] = h;
if (groups->find(h) == groups->end()) {
if (groups.find(h) == groups.end()) {
std::vector<AnfNodePtr> innervec({node});
(*groups)[h] = innervec;
order_group->emplace_back(h);
groups[h] = innervec;
order_group.emplace_back(h);
} else {
(*groups)[h].push_back(node);
groups[h].push_back(node);
}
}
changed = DoReplace(manager, order_group, &groups) || changed;
}
return changed;
}
} // namespace
bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(main);
......@@ -177,10 +180,7 @@ bool CSE::Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) const
MS_EXCEPTION_IF_NULL(manager);
manager->AddFuncGraph(root);
std::unordered_map<std::size_t, std::vector<AnfNodePtr>> groups;
std::vector<std::size_t> order_group;
BuildOrderGroup(manager, &order_group, &groups);
return DoReplace(manager, order_group, &groups);
return BuildOrderGroupAndDoReplace(manager);
}
} // namespace opt
} // namespace mindspore
......@@ -46,6 +46,7 @@ class CSE {
bool Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) const;
private:
bool BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const;
bool DoReplace(const FuncGraphManagerPtr manager, const std::vector<std::size_t> &order_group,
std::unordered_map<std::size_t, std::vector<AnfNodePtr>> *groups) const;
bool report_changes_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册