提交 737bfc95 编写于 作者: Z Zhang Qinghua

Use FuncGraph generation number replacing set.

上级 f31564ce
......@@ -47,6 +47,7 @@ FuncGraph::FuncGraph()
: flags_(),
transforms_(),
parameter_default_value_(),
seen_(0),
parameters_(),
has_vararg_(false),
has_kwarg_(false),
......@@ -981,6 +982,11 @@ void FuncGraph::SetEffectDepends(const std::vector<AnfNodePtr> &depend_inputs) {
}
}
size_t NewFgSeenGeneration() {
static size_t fg_seen_generation = 0;
return ++fg_seen_generation;
}
const PrimitivePtr FuncGraphTransform::func_graph_prim_ = std::make_shared<Primitive>("FuncGraph");
const char kFuncGraphFlagUndetermined[] = "Undeterminate";
} // namespace mindspore
......@@ -289,6 +289,7 @@ class FuncGraph : public FuncGraphBase {
// parameter default value
std::map<std::string, AnfNodePtr> parameter_default_value_;
std::unordered_map<AnfNodePtr, AnfNodePtr> make_ref_params_;
size_t seen_;
std::list<CNodePtr> GetOrderedCnodes();
void EraseUnusedNodeInOrder(const AnfNodePtr &n);
......@@ -364,6 +365,8 @@ inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphP
return fg->NewCNode(inputs);
}
size_t NewFgSeenGeneration();
// Find the root cnodes of a segment of cnodes.
std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment);
// Find the leaf cnodes of a segment of cnodes.
......
......@@ -755,8 +755,8 @@ void DepComputer::Recompute(const FuncGraphPtr &fg) {
}
}
FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path) {
if (path == nullptr || path->contains(fg)) {
FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &fg, size_t seen_num) {
if (fg->seen_ == seen_num) {
return std::make_shared<FuncGraphSet>();
}
FuncGraphSetPtr parents = std::make_shared<FuncGraphSet>();
......@@ -770,9 +770,9 @@ FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &f
// Search the fv in fg's child func graph.
auto &fg_value_nodes = fg->func_graph_value_nodes();
for (auto &fg_value_node : fg_value_nodes) {
path->add(fg);
fg->seen_ = seen_num;
auto gt = GetValueNode<FuncGraphPtr>(fg_value_node.first);
parents->update(SeekParents(gt, path));
parents->update(SeekParents(gt, seen_num));
}
(void)parents->erase(fg);
return parents;
......@@ -780,7 +780,7 @@ FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &f
void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) {
MS_EXCEPTION_IF_NULL(fg);
func_graph_parents_total_analysis_[fg].update(SeekParents(fg));
func_graph_parents_total_analysis_[fg].update(SeekParents(fg, NewFgSeenGeneration()));
}
bool set_len_compare(const FuncGraphSetPair &lhs, const FuncGraphSetPair &rhs) {
......@@ -968,9 +968,8 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<F
}
}
bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPtr &path) {
MS_EXCEPTION_IF_NULL(path);
if (path->contains(fg)) {
bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, size_t seen_num) {
if (fg->seen_ == seen_num) {
MS_LOG(DEBUG) << fg->ToString() << " had been checked";
return false;
}
......@@ -978,19 +977,20 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPt
if (!j_fg_value_nodes.empty()) {
// check g1->J(fg)->g2->g cycle;
auto contains_j =
std::find_if(j_fg_value_nodes.begin(), j_fg_value_nodes.end(),
[path](const std::pair<AnfNodePtr, int> iter) { return !path->contains(GetValueNode<FuncGraphPtr>(iter.first)); });
std::find_if(j_fg_value_nodes.begin(), j_fg_value_nodes.end(), [seen_num](const std::pair<AnfNodePtr, int> iter) {
return GetValueNode<FuncGraphPtr>(iter.first)->seen_ != seen_num;
});
if (contains_j != j_fg_value_nodes.end()) {
MS_LOG(DEBUG) << fg->ToString() << " contains J(" << contains_j->first->ToString() << ")";
return true;
}
}
path->add(fg);
fg->seen_ = seen_num;
// check if func graphs used contains J(func_graph);
for (auto &item : fg->func_graph_value_nodes()) {
auto used_g = GetValueNode<FuncGraphPtr>(item.first);
if (SeekJ(used_g, path)) {
if (SeekJ(used_g, seen_num)) {
MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString() << " which contains J(func_graph)";
return true;
}
......@@ -1000,7 +1000,6 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPt
}
void FuncGraphJTotalComputer::RealRecompute(FuncGraphPtr fg) {
std::shared_ptr<FuncGraphSet> path = std::make_shared<FuncGraphSet>();
this->j_total_analysis_[fg] = SeekJ(fg, path);
this->j_total_analysis_[fg] = SeekJ(fg, NewFgSeenGeneration());
}
} // namespace mindspore
......@@ -283,7 +283,7 @@ class FuncGraphParentsTotalComputer final : public DepComputer {
void RealRecompute(FuncGraphPtr fg) override;
private:
FuncGraphSetPtr SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path = std::make_shared<FuncGraphSet>());
FuncGraphSetPtr SeekParents(const FuncGraphPtr &fg, size_t seen_num);
};
using FuncGraphToFuncGraphMap = OrderedMap<FuncGraphPtr, FuncGraphPtr>;
......@@ -423,7 +423,7 @@ class FuncGraphJTotalComputer final : public DepComputer {
void ExtraReset() override { j_total_analysis_.clear(); }
void RealRecompute(FuncGraphPtr fg) override;
bool SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPtr &path);
bool SeekJ(const FuncGraphPtr &fg, size_t seen_num);
};
class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册