diff --git a/mindspore/ccsrc/optimizer/cse.cc b/mindspore/ccsrc/optimizer/cse.cc index 82050f61088d7b228619a5afa45295d9d810b782..42ebf5a658e64bdfa8efa1cac537ee901dd1d1dc 100644 --- a/mindspore/ccsrc/optimizer/cse.cc +++ b/mindspore/ccsrc/optimizer/cse.cc @@ -40,14 +40,14 @@ BasePtr AbsOf(const AnfNodePtr &node) { return node_abs; } -namespace { -void BuildOrderGroup(const FuncGraphManagerPtr manager, std::vector *const order_group, - std::unordered_map> *groups) { - MS_EXCEPTION_IF_NULL(order_group); - - std::unordered_map hashes; +bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const { + bool changed = false; for (FuncGraphPtr fg : manager->func_graphs()) { MS_EXCEPTION_IF_NULL(fg); + std::vector order_group; + std::unordered_map> groups; + std::unordered_map hashes; + std::vector toposet = TopoSort(fg->get_return()); for (auto node : toposet) { MS_EXCEPTION_IF_NULL(node); @@ -75,17 +75,20 @@ void BuildOrderGroup(const FuncGraphManagerPtr manager, std::vector } hashes[node] = h; - if (groups->find(h) == groups->end()) { + if (groups.find(h) == groups.end()) { std::vector innervec({node}); - (*groups)[h] = innervec; - order_group->emplace_back(h); + groups[h] = innervec; + order_group.emplace_back(h); } else { - (*groups)[h].push_back(node); + groups[h].push_back(node); } } + + changed = DoReplace(manager, order_group, &groups) || changed; } + + return changed; } -} // namespace bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const { MS_EXCEPTION_IF_NULL(main); @@ -177,10 +180,7 @@ bool CSE::Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) const MS_EXCEPTION_IF_NULL(manager); manager->AddFuncGraph(root); - std::unordered_map> groups; - std::vector order_group; - BuildOrderGroup(manager, &order_group, &groups); - return DoReplace(manager, order_group, &groups); + return BuildOrderGroupAndDoReplace(manager); } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/cse.h b/mindspore/ccsrc/optimizer/cse.h index 823b24edb7a6c0d91311e455864023a4766cc950..544e6cb6a36f7c14dc99b781d1258522ea043c68 100644 --- a/mindspore/ccsrc/optimizer/cse.h +++ b/mindspore/ccsrc/optimizer/cse.h @@ -46,6 +46,7 @@ class CSE { bool Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) const; private: + bool BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const; bool DoReplace(const FuncGraphManagerPtr manager, const std::vector &order_group, std::unordered_map> *groups) const; bool report_changes_;