提交 737bfc95 编写于 作者: Z Zhang Qinghua

Use FuncGraph generation number replacing set.

上级 f31564ce
...@@ -47,6 +47,7 @@ FuncGraph::FuncGraph() ...@@ -47,6 +47,7 @@ FuncGraph::FuncGraph()
: flags_(), : flags_(),
transforms_(), transforms_(),
parameter_default_value_(), parameter_default_value_(),
seen_(0),
parameters_(), parameters_(),
has_vararg_(false), has_vararg_(false),
has_kwarg_(false), has_kwarg_(false),
...@@ -981,6 +982,11 @@ void FuncGraph::SetEffectDepends(const std::vector<AnfNodePtr> &depend_inputs) { ...@@ -981,6 +982,11 @@ void FuncGraph::SetEffectDepends(const std::vector<AnfNodePtr> &depend_inputs) {
} }
} }
size_t NewFgSeenGeneration() {
static size_t fg_seen_generation = 0;
return ++fg_seen_generation;
}
const PrimitivePtr FuncGraphTransform::func_graph_prim_ = std::make_shared<Primitive>("FuncGraph"); const PrimitivePtr FuncGraphTransform::func_graph_prim_ = std::make_shared<Primitive>("FuncGraph");
const char kFuncGraphFlagUndetermined[] = "Undeterminate"; const char kFuncGraphFlagUndetermined[] = "Undeterminate";
} // namespace mindspore } // namespace mindspore
...@@ -289,6 +289,7 @@ class FuncGraph : public FuncGraphBase { ...@@ -289,6 +289,7 @@ class FuncGraph : public FuncGraphBase {
// parameter default value // parameter default value
std::map<std::string, AnfNodePtr> parameter_default_value_; std::map<std::string, AnfNodePtr> parameter_default_value_;
std::unordered_map<AnfNodePtr, AnfNodePtr> make_ref_params_; std::unordered_map<AnfNodePtr, AnfNodePtr> make_ref_params_;
size_t seen_;
std::list<CNodePtr> GetOrderedCnodes(); std::list<CNodePtr> GetOrderedCnodes();
void EraseUnusedNodeInOrder(const AnfNodePtr &n); void EraseUnusedNodeInOrder(const AnfNodePtr &n);
...@@ -364,6 +365,8 @@ inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphP ...@@ -364,6 +365,8 @@ inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphP
return fg->NewCNode(inputs); return fg->NewCNode(inputs);
} }
size_t NewFgSeenGeneration();
// Find the root cnodes of a segment of cnodes. // Find the root cnodes of a segment of cnodes.
std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment); std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment);
// Find the leaf cnodes of a segment of cnodes. // Find the leaf cnodes of a segment of cnodes.
......
...@@ -755,8 +755,8 @@ void DepComputer::Recompute(const FuncGraphPtr &fg) { ...@@ -755,8 +755,8 @@ void DepComputer::Recompute(const FuncGraphPtr &fg) {
} }
} }
FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path) { FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &fg, size_t seen_num) {
if (path == nullptr || path->contains(fg)) { if (fg->seen_ == seen_num) {
return std::make_shared<FuncGraphSet>(); return std::make_shared<FuncGraphSet>();
} }
FuncGraphSetPtr parents = std::make_shared<FuncGraphSet>(); FuncGraphSetPtr parents = std::make_shared<FuncGraphSet>();
...@@ -770,9 +770,9 @@ FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &f ...@@ -770,9 +770,9 @@ FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &f
// Search the fv in fg's child func graph. // Search the fv in fg's child func graph.
auto &fg_value_nodes = fg->func_graph_value_nodes(); auto &fg_value_nodes = fg->func_graph_value_nodes();
for (auto &fg_value_node : fg_value_nodes) { for (auto &fg_value_node : fg_value_nodes) {
path->add(fg); fg->seen_ = seen_num;
auto gt = GetValueNode<FuncGraphPtr>(fg_value_node.first); auto gt = GetValueNode<FuncGraphPtr>(fg_value_node.first);
parents->update(SeekParents(gt, path)); parents->update(SeekParents(gt, seen_num));
} }
(void)parents->erase(fg); (void)parents->erase(fg);
return parents; return parents;
...@@ -780,7 +780,7 @@ FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &f ...@@ -780,7 +780,7 @@ FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &f
void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) { void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) {
MS_EXCEPTION_IF_NULL(fg); MS_EXCEPTION_IF_NULL(fg);
func_graph_parents_total_analysis_[fg].update(SeekParents(fg)); func_graph_parents_total_analysis_[fg].update(SeekParents(fg, NewFgSeenGeneration()));
} }
bool set_len_compare(const FuncGraphSetPair &lhs, const FuncGraphSetPair &rhs) { bool set_len_compare(const FuncGraphSetPair &lhs, const FuncGraphSetPair &rhs) {
...@@ -968,9 +968,8 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<F ...@@ -968,9 +968,8 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<F
} }
} }
bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPtr &path) { bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, size_t seen_num) {
MS_EXCEPTION_IF_NULL(path); if (fg->seen_ == seen_num) {
if (path->contains(fg)) {
MS_LOG(DEBUG) << fg->ToString() << " had been checked"; MS_LOG(DEBUG) << fg->ToString() << " had been checked";
return false; return false;
} }
...@@ -978,19 +977,20 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPt ...@@ -978,19 +977,20 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPt
if (!j_fg_value_nodes.empty()) { if (!j_fg_value_nodes.empty()) {
// check g1->J(fg)->g2->g cycle; // check g1->J(fg)->g2->g cycle;
auto contains_j = auto contains_j =
std::find_if(j_fg_value_nodes.begin(), j_fg_value_nodes.end(), std::find_if(j_fg_value_nodes.begin(), j_fg_value_nodes.end(), [seen_num](const std::pair<AnfNodePtr, int> iter) {
[path](const std::pair<AnfNodePtr, int> iter) { return !path->contains(GetValueNode<FuncGraphPtr>(iter.first)); }); return GetValueNode<FuncGraphPtr>(iter.first)->seen_ != seen_num;
});
if (contains_j != j_fg_value_nodes.end()) { if (contains_j != j_fg_value_nodes.end()) {
MS_LOG(DEBUG) << fg->ToString() << " contains J(" << contains_j->first->ToString() << ")"; MS_LOG(DEBUG) << fg->ToString() << " contains J(" << contains_j->first->ToString() << ")";
return true; return true;
} }
} }
path->add(fg); fg->seen_ = seen_num;
// check if func graphs used contains J(func_graph); // check if func graphs used contains J(func_graph);
for (auto &item : fg->func_graph_value_nodes()) { for (auto &item : fg->func_graph_value_nodes()) {
auto used_g = GetValueNode<FuncGraphPtr>(item.first); auto used_g = GetValueNode<FuncGraphPtr>(item.first);
if (SeekJ(used_g, path)) { if (SeekJ(used_g, seen_num)) {
MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString() << " which contains J(func_graph)"; MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString() << " which contains J(func_graph)";
return true; return true;
} }
...@@ -1000,7 +1000,6 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPt ...@@ -1000,7 +1000,6 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPt
} }
void FuncGraphJTotalComputer::RealRecompute(FuncGraphPtr fg) { void FuncGraphJTotalComputer::RealRecompute(FuncGraphPtr fg) {
std::shared_ptr<FuncGraphSet> path = std::make_shared<FuncGraphSet>(); this->j_total_analysis_[fg] = SeekJ(fg, NewFgSeenGeneration());
this->j_total_analysis_[fg] = SeekJ(fg, path);
} }
} // namespace mindspore } // namespace mindspore
...@@ -283,7 +283,7 @@ class FuncGraphParentsTotalComputer final : public DepComputer { ...@@ -283,7 +283,7 @@ class FuncGraphParentsTotalComputer final : public DepComputer {
void RealRecompute(FuncGraphPtr fg) override; void RealRecompute(FuncGraphPtr fg) override;
private: private:
FuncGraphSetPtr SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path = std::make_shared<FuncGraphSet>()); FuncGraphSetPtr SeekParents(const FuncGraphPtr &fg, size_t seen_num);
}; };
using FuncGraphToFuncGraphMap = OrderedMap<FuncGraphPtr, FuncGraphPtr>; using FuncGraphToFuncGraphMap = OrderedMap<FuncGraphPtr, FuncGraphPtr>;
...@@ -423,7 +423,7 @@ class FuncGraphJTotalComputer final : public DepComputer { ...@@ -423,7 +423,7 @@ class FuncGraphJTotalComputer final : public DepComputer {
void ExtraReset() override { j_total_analysis_.clear(); } void ExtraReset() override { j_total_analysis_.clear(); }
void RealRecompute(FuncGraphPtr fg) override; void RealRecompute(FuncGraphPtr fg) override;
bool SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPtr &path); bool SeekJ(const FuncGraphPtr &fg, size_t seen_num);
}; };
class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册