提交 9efd2eed 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!926 [ir] add seen generation to accelerate traversing the whole graph

Merge pull request !926 from biffex/ir-add-seen-generation-to-accelerate-traverse-the-whole-graph
...@@ -227,6 +227,12 @@ bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value) { ...@@ -227,6 +227,12 @@ bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value) {
} }
return false; return false;
} }
size_t NewSeenGeneration() {
static size_t seen_generation = 0;
return ++seen_generation;
}
namespace id_generator { namespace id_generator {
static std::unordered_map<std::string, int> node_ids; static std::unordered_map<std::string, int> node_ids;
std::string get_id(const AnfNodePtr &node) { std::string get_id(const AnfNodePtr &node) {
......
...@@ -155,6 +155,7 @@ class AnfNode : public Base { ...@@ -155,6 +155,7 @@ class AnfNode : public Base {
os << node.ToString(); os << node.ToString();
return os; return os;
} }
size_t seen_{0};
protected: protected:
// Hold a weak ref to Graph as Graph also hold ref to AnfNode. // Hold a weak ref to Graph as Graph also hold ref to AnfNode.
...@@ -429,6 +430,9 @@ inline S GetValueNode(const AnfNodePtr &node) { ...@@ -429,6 +430,9 @@ inline S GetValueNode(const AnfNodePtr &node) {
auto s = value->cast<S>(); auto s = value->cast<S>();
return s; return s;
} }
size_t NewSeenGeneration();
namespace id_generator { namespace id_generator {
std::string get_id(const AnfNodePtr &node); std::string get_id(const AnfNodePtr &node);
void reset_id(); void reset_id();
......
...@@ -90,20 +90,26 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNode ...@@ -90,20 +90,26 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNode
bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &root_node, bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &root_node,
const SubstitutionPtr &transform) const { const SubstitutionPtr &transform) const {
#ifdef ENABLE_PROFILE
double start = GetTime();
#endif
FuncGraphManagerPtr manager = optimizer->manager(); FuncGraphManagerPtr manager = optimizer->manager();
std::unordered_set<AnfNodePtr> seen_node; auto seen = NewSeenGeneration();
std::deque<AnfNodePtr> todo{root_node}; // 1024 is for the initial capacity of deque
std::deque<AnfNodePtr> todo(1024);
todo.push_back(root_node);
bool changes = false; bool changes = false;
auto &all_nodes = manager->all_nodes();
while (!todo.empty()) { while (!todo.empty()) {
AnfNodePtr node = todo.front(); AnfNodePtr node = todo.front();
todo.pop_front(); todo.pop_front();
// check whether this node has been matched. // check whether this node has been matched.
if (seen_node.find(node) != seen_node.end() || !manager->all_nodes().contains(node)) { if (node == nullptr || node->seen_ == seen || !all_nodes.contains(node)) {
continue; continue;
} }
(void)seen_node.insert(node); node->seen_ = seen;
// select nodes that this transform can be applied. // select nodes that this transform can be applied.
bool is_match = transform->predicate_(node); bool is_match = transform->predicate_(node);
...@@ -114,6 +120,7 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo ...@@ -114,6 +120,7 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo
auto ret = (*transform)(optimizer, node); auto ret = (*transform)(optimizer, node);
if (ret != nullptr && ret != node) { if (ret != nullptr && ret != node) {
change = true; change = true;
changes = true;
#ifdef ENABLE_PROFILE #ifdef ENABLE_PROFILE
double t = GetTime(); double t = GetTime();
#endif #endif
...@@ -139,16 +146,20 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo ...@@ -139,16 +146,20 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo
if (change && node_users.find(node) != node_users.end()) { if (change && node_users.find(node) != node_users.end()) {
for (auto &use : node_users[node]) { for (auto &use : node_users[node]) {
auto use_node = use.first; auto use_node = use.first;
if (use_node == nullptr) {
continue;
}
todo.push_back(use_node); todo.push_back(use_node);
if (seen_node.find(use_node) != seen_node.end()) { if (use_node->seen_ == seen) {
(void)seen_node.erase(use_node); use_node->seen_--;
} }
} }
} }
changes = changes || change;
} }
#ifdef ENABLE_PROFILE
MsProfile::StatTime("opt.transform", GetTime() - start);
#endif
return changes; return changes;
} }
......
...@@ -48,8 +48,8 @@ class Substitution { ...@@ -48,8 +48,8 @@ class Substitution {
PredicateFuncType predicate_{nullptr}; PredicateFuncType predicate_{nullptr};
// an enum to mark this Substitution relation to renormalize pass // an enum to mark this Substitution relation to renormalize pass
RenormAction renorm_action_; RenormAction renorm_action_;
explicit Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate, Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate,
const RenormAction &renorm_action) const RenormAction &renorm_action)
: transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {} : transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {}
~Substitution() = default; ~Substitution() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const; AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const;
......
...@@ -46,17 +46,18 @@ class DeepFirstSearcher : public AnfVisitor { ...@@ -46,17 +46,18 @@ class DeepFirstSearcher : public AnfVisitor {
if (root == nullptr) { if (root == nullptr) {
return res_; return res_;
} }
seen_ = NewSeenGeneration();
Visit(root); Visit(root);
return res_; return res_;
} }
void Visit(const AnfNodePtr &node) override { void Visit(const AnfNodePtr &node) override {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (seen_.count(node) != 0) { if (node->seen_ == seen_) {
return; return;
} }
(void)seen_.insert(node); node->seen_ = seen_;
auto incl = include_(node); auto incl = include_(node);
if (incl == EXCLUDE) { if (incl == EXCLUDE) {
...@@ -70,9 +71,9 @@ class DeepFirstSearcher : public AnfVisitor { ...@@ -70,9 +71,9 @@ class DeepFirstSearcher : public AnfVisitor {
} }
private: private:
size_t seen_{0};
IncludeFunc include_; IncludeFunc include_;
std::vector<AnfNodePtr> res_{}; std::vector<AnfNodePtr> res_{};
std::set<AnfNodePtr> seen_{};
}; };
class DeepScopedGraphSearcher : public DeepFirstSearcher { class DeepScopedGraphSearcher : public DeepFirstSearcher {
...@@ -174,14 +175,14 @@ std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const Incl ...@@ -174,14 +175,14 @@ std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const Incl
} }
std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, const IncludeFunc &include) { std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, const IncludeFunc &include) {
std::unordered_set<AnfNodePtr> done; size_t seen = NewSeenGeneration();
std::list<AnfNodePtr> todo(1, root); std::list<AnfNodePtr> todo(1, root);
std::unordered_map<AnfNodePtr, size_t> rank; std::unordered_map<AnfNodePtr, size_t> rank;
std::vector<AnfNodePtr> res; std::vector<AnfNodePtr> res;
while (!todo.empty()) { while (!todo.empty()) {
AnfNodePtr node = todo.back(); AnfNodePtr node = todo.back();
if (done.find(node) != done.end()) { if (node == nullptr || node->seen_ == seen) {
todo.pop_back(); todo.pop_back();
continue; continue;
} }
...@@ -194,7 +195,7 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c ...@@ -194,7 +195,7 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c
if (incl == FOLLOW) { if (incl == FOLLOW) {
auto succs = succ(node); auto succs = succ(node);
for (const auto i : succs) { for (const auto i : succs) {
if ((done.find(i) == done.end()) if ((i != nullptr && i->seen_ != seen)
// Handle the case for 2 subgraphs calls each other. // Handle the case for 2 subgraphs calls each other.
// If the ValueNodeGraph's return is already in the todo list, do not follow it. // If the ValueNodeGraph's return is already in the todo list, do not follow it.
&& !((std::find(todo.begin(), todo.end(), i) != todo.end()) && (i->func_graph() != nullptr) && && !((std::find(todo.begin(), todo.end(), i) != todo.end()) && (i->func_graph() != nullptr) &&
...@@ -206,7 +207,7 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c ...@@ -206,7 +207,7 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c
} else if (incl == NOFOLLOW) { } else if (incl == NOFOLLOW) {
// do nothing // do nothing
} else if (incl == EXCLUDE) { } else if (incl == EXCLUDE) {
(void)done.insert(node); node->seen_ = seen;
todo.pop_back(); todo.pop_back();
continue; continue;
} else { } else {
...@@ -215,7 +216,7 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c ...@@ -215,7 +216,7 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c
if (cont) { if (cont) {
continue; continue;
} }
(void)done.insert(node); node->seen_ = seen;
res.push_back(node); res.push_back(node);
todo.pop_back(); todo.pop_back();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册