diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc index d168ae09b56bcff740ee99953d888841dd467a45..510b4cbd24b8347eb61335725b33b087d2c55502 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc @@ -17,6 +17,8 @@ */ #include "pipeline/jit/parse/parse.h" + +#include #include #include #include @@ -1480,21 +1482,25 @@ AnfNodePtr FindPhis(const std::unordered_map &removabl void Parser::RemoveUnnecessaryPhis() { // merge all removable phis to one map; std::unordered_map removable_phis; + std::vector phis; for (FunctionBlockPtr &block : func_block_list_) { MS_EXCEPTION_IF_NULL(block); removable_phis.insert(block->removable_phis().begin(), block->removable_phis().end()); + std::transform(block->removable_phis().begin(), block->removable_phis().end(), std::back_inserter(phis), + [](std::pair pair) { return pair.first; }); } if (removable_phis.size() == 0) { return; } - auto fg_name = func_graph_->ToString(); auto mng = Manage(func_graph_, false); // replace the nodes - for (auto iter : removable_phis) { - auto new_node = FindPhis(removable_phis, iter.first); - MS_LOG(DEBUG) << "phi " << iter.first->DebugString() << " to " << new_node->DebugString(); - mng->Replace(iter.first, new_node); + // remove from inside to outside + for (int idx = SizeToInt(phis.size() - 1); idx >= 0; idx--) { + auto phi = phis[IntToSize(idx)]; + auto new_node = FindPhis(removable_phis, phi); + MS_LOG(DEBUG) << "phi " << phi->DebugString() << " to " << new_node->DebugString(); + mng->Replace(phi, new_node); } // remove the parameter for (FunctionBlockPtr &block : func_block_list_) { diff --git a/mindspore/ccsrc/utils/ordered_map.h b/mindspore/ccsrc/utils/ordered_map.h index 56ae281dcf17b93034609f2f6bdd69390fe35458..50be0971280839890c60351eec2a56321b7f1805 100644 --- a/mindspore/ccsrc/utils/ordered_map.h +++ b/mindspore/ccsrc/utils/ordered_map.h @@ -124,7 +124,7 @@ class OrderedMap { std::pair insert(const pair_type &kv) { auto result = add(kv.first); if (result.second) { - *(result.first) = kv.second; + *(result.first) = kv; return std::make_pair(std::prev(end()), true); } return std::make_pair(result.first, false);