diff --git a/mindspore/ccsrc/ir/func_graph.cc b/mindspore/ccsrc/ir/func_graph.cc index 4dc84cfa0191179a270b0ee4ec167054252ed7ec..4833e3838b7856471479726463df85abe9ee7ec6 100644 --- a/mindspore/ccsrc/ir/func_graph.cc +++ b/mindspore/ccsrc/ir/func_graph.cc @@ -47,6 +47,7 @@ FuncGraph::FuncGraph() : flags_(), transforms_(), parameter_default_value_(), + seen_(0), parameters_(), has_vararg_(false), has_kwarg_(false), @@ -981,6 +982,11 @@ void FuncGraph::SetEffectDepends(const std::vector &depend_inputs) { } } +size_t NewFgSeenGeneration() { + static size_t fg_seen_generation = 0; + return ++fg_seen_generation; +} + const PrimitivePtr FuncGraphTransform::func_graph_prim_ = std::make_shared("FuncGraph"); const char kFuncGraphFlagUndetermined[] = "Undeterminate"; } // namespace mindspore diff --git a/mindspore/ccsrc/ir/func_graph.h b/mindspore/ccsrc/ir/func_graph.h index 02a18a98092de6588c51c4586a7e00576ffc907d..f4c9d7079f30f168eae1c0a1fc50f8a8da5884d8 100644 --- a/mindspore/ccsrc/ir/func_graph.h +++ b/mindspore/ccsrc/ir/func_graph.h @@ -289,6 +289,7 @@ class FuncGraph : public FuncGraphBase { // parameter default value std::map parameter_default_value_; std::unordered_map make_ref_params_; + size_t seen_; std::list GetOrderedCnodes(); void EraseUnusedNodeInOrder(const AnfNodePtr &n); @@ -364,6 +365,8 @@ inline CNodePtr NewCNode(const std::vector &inputs, const FuncGraphP return fg->NewCNode(inputs); } +size_t NewFgSeenGeneration(); + // Find the root cnodes of a segment of cnodes. std::shared_ptr> FindRoots(const std::vector &segment); // Find the leaf cnodes of a segment of cnodes. diff --git a/mindspore/ccsrc/ir/manager.cc b/mindspore/ccsrc/ir/manager.cc index 00b31d543ed189bca0f186fe95f26696955c7c89..cfaa84a05bdf2d742118991d6958e90d63e1ecd9 100644 --- a/mindspore/ccsrc/ir/manager.cc +++ b/mindspore/ccsrc/ir/manager.cc @@ -755,8 +755,8 @@ void DepComputer::Recompute(const FuncGraphPtr &fg) { } } -FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path) { - if (path == nullptr || path->contains(fg)) { +FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &fg, size_t seen_num) { + if (fg->seen_ == seen_num) { return std::make_shared(); } FuncGraphSetPtr parents = std::make_shared(); @@ -770,9 +770,9 @@ FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &f // Search the fv in fg's child func graph. auto &fg_value_nodes = fg->func_graph_value_nodes(); for (auto &fg_value_node : fg_value_nodes) { - path->add(fg); + fg->seen_ = seen_num; auto gt = GetValueNode(fg_value_node.first); - parents->update(SeekParents(gt, path)); + parents->update(SeekParents(gt, seen_num)); } (void)parents->erase(fg); return parents; @@ -780,7 +780,7 @@ FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &f void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) { MS_EXCEPTION_IF_NULL(fg); - func_graph_parents_total_analysis_[fg].update(SeekParents(fg)); + func_graph_parents_total_analysis_[fg].update(SeekParents(fg, NewFgSeenGeneration())); } bool set_len_compare(const FuncGraphSetPair &lhs, const FuncGraphSetPair &rhs) { @@ -968,9 +968,8 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::listcontains(fg)) { +bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, size_t seen_num) { + if (fg->seen_ == seen_num) { MS_LOG(DEBUG) << fg->ToString() << " had been checked"; return false; } @@ -978,19 +977,20 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPt if (!j_fg_value_nodes.empty()) { // check g1->J(fg)->g2->g cycle; auto contains_j = - std::find_if(j_fg_value_nodes.begin(), j_fg_value_nodes.end(), - [path](const std::pair iter) { return !path->contains(GetValueNode(iter.first)); }); + std::find_if(j_fg_value_nodes.begin(), j_fg_value_nodes.end(), [seen_num](const std::pair iter) { + return GetValueNode(iter.first)->seen_ != seen_num; + }); if (contains_j != j_fg_value_nodes.end()) { MS_LOG(DEBUG) << fg->ToString() << " contains J(" << contains_j->first->ToString() << ")"; return true; } } - path->add(fg); + fg->seen_ = seen_num; // check if func graphs used contains J(func_graph); for (auto &item : fg->func_graph_value_nodes()) { auto used_g = GetValueNode(item.first); - if (SeekJ(used_g, path)) { + if (SeekJ(used_g, seen_num)) { MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString() << " which contains J(func_graph)"; return true; } @@ -1000,7 +1000,6 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPt } void FuncGraphJTotalComputer::RealRecompute(FuncGraphPtr fg) { - std::shared_ptr path = std::make_shared(); - this->j_total_analysis_[fg] = SeekJ(fg, path); + this->j_total_analysis_[fg] = SeekJ(fg, NewFgSeenGeneration()); } } // namespace mindspore diff --git a/mindspore/ccsrc/ir/manager.h b/mindspore/ccsrc/ir/manager.h index 06b2859feace0a0d2613af05b37337e1c17a2ec6..d748a08593f03881284b0ed9a74dc94aa14c428a 100644 --- a/mindspore/ccsrc/ir/manager.h +++ b/mindspore/ccsrc/ir/manager.h @@ -283,7 +283,7 @@ class FuncGraphParentsTotalComputer final : public DepComputer { void RealRecompute(FuncGraphPtr fg) override; private: - FuncGraphSetPtr SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path = std::make_shared()); + FuncGraphSetPtr SeekParents(const FuncGraphPtr &fg, size_t seen_num); }; using FuncGraphToFuncGraphMap = OrderedMap; @@ -423,7 +423,7 @@ class FuncGraphJTotalComputer final : public DepComputer { void ExtraReset() override { j_total_analysis_.clear(); } void RealRecompute(FuncGraphPtr fg) override; - bool SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPtr &path); + bool SeekJ(const FuncGraphPtr &fg, size_t seen_num); }; class FuncGraphManager : public std::enable_shared_from_this {