提交 9efd2eed 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!926 [ir] add seen generation to accelerate traversing the whole graph

Merge pull request !926 from biffex/ir-add-seen-generation-to-accelerate-traverse-the-whole-graph
......@@ -227,6 +227,12 @@ bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value) {
}
return false;
}
// Produces the next traversal generation number.
//
// Each call returns a value one greater than the previous call, and the very
// first call returns 1, so a freshly issued generation never equals 0 (short
// of size_t wrap-around). The counter is a plain function-local static with
// no synchronization around the increment.
size_t NewSeenGeneration() {
  static size_t generation_counter = 0;
  generation_counter += 1;
  return generation_counter;
}
namespace id_generator {
static std::unordered_map<std::string, int> node_ids;
std::string get_id(const AnfNodePtr &node) {
......
......@@ -155,6 +155,7 @@ class AnfNode : public Base {
os << node.ToString();
return os;
}
size_t seen_{0};
protected:
// Hold a weak ref to Graph because Graph also holds a ref to AnfNode.
......@@ -429,6 +430,9 @@ inline S GetValueNode(const AnfNodePtr &node) {
auto s = value->cast<S>();
return s;
}
// Returns a new generation number for marking visited nodes during a graph
// traversal. Successive calls return strictly increasing values starting at 1.
size_t NewSeenGeneration();
namespace id_generator {
std::string get_id(const AnfNodePtr &node);
void reset_id();
......
......@@ -90,20 +90,26 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNode
bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &root_node,
const SubstitutionPtr &transform) const {
#ifdef ENABLE_PROFILE
double start = GetTime();
#endif
FuncGraphManagerPtr manager = optimizer->manager();
std::unordered_set<AnfNodePtr> seen_node;
std::deque<AnfNodePtr> todo{root_node};
auto seen = NewSeenGeneration();
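// NOTE(review): std::deque's count constructor creates 1024 null entries (it does not reserve capacity); the nullptr check in the loop below skips them.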
std::deque<AnfNodePtr> todo(1024);
todo.push_back(root_node);
bool changes = false;
auto &all_nodes = manager->all_nodes();
while (!todo.empty()) {
AnfNodePtr node = todo.front();
todo.pop_front();
// check whether this node has been matched.
if (seen_node.find(node) != seen_node.end() || !manager->all_nodes().contains(node)) {
if (node == nullptr || node->seen_ == seen || !all_nodes.contains(node)) {
continue;
}
(void)seen_node.insert(node);
node->seen_ = seen;
// select nodes to which this transform can be applied.
bool is_match = transform->predicate_(node);
......@@ -114,6 +120,7 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo
auto ret = (*transform)(optimizer, node);
if (ret != nullptr && ret != node) {
change = true;
changes = true;
#ifdef ENABLE_PROFILE
double t = GetTime();
#endif
......@@ -139,16 +146,20 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo
if (change && node_users.find(node) != node_users.end()) {
for (auto &use : node_users[node]) {
auto use_node = use.first;
if (use_node == nullptr) {
continue;
}
todo.push_back(use_node);
if (seen_node.find(use_node) != seen_node.end()) {
(void)seen_node.erase(use_node);
if (use_node->seen_ == seen) {
use_node->seen_--;
}
}
}
changes = changes || change;
}
#ifdef ENABLE_PROFILE
MsProfile::StatTime("opt.transform", GetTime() - start);
#endif
return changes;
}
......
......@@ -48,8 +48,8 @@ class Substitution {
PredicateFuncType predicate_{nullptr};
// an enum marking how this Substitution relates to the renormalize pass
RenormAction renorm_action_;
explicit Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate,
const RenormAction &renorm_action)
Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate,
const RenormAction &renorm_action)
: transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {}
~Substitution() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const;
......
......@@ -46,17 +46,18 @@ class DeepFirstSearcher : public AnfVisitor {
if (root == nullptr) {
return res_;
}
seen_ = NewSeenGeneration();
Visit(root);
return res_;
}
void Visit(const AnfNodePtr &node) override {
MS_EXCEPTION_IF_NULL(node);
if (seen_.count(node) != 0) {
if (node->seen_ == seen_) {
return;
}
(void)seen_.insert(node);
node->seen_ = seen_;
auto incl = include_(node);
if (incl == EXCLUDE) {
......@@ -70,9 +71,9 @@ class DeepFirstSearcher : public AnfVisitor {
}
private:
size_t seen_{0};
IncludeFunc include_;
std::vector<AnfNodePtr> res_{};
std::set<AnfNodePtr> seen_{};
};
class DeepScopedGraphSearcher : public DeepFirstSearcher {
......@@ -174,14 +175,14 @@ std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const Incl
}
std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, const IncludeFunc &include) {
std::unordered_set<AnfNodePtr> done;
size_t seen = NewSeenGeneration();
std::list<AnfNodePtr> todo(1, root);
std::unordered_map<AnfNodePtr, size_t> rank;
std::vector<AnfNodePtr> res;
while (!todo.empty()) {
AnfNodePtr node = todo.back();
if (done.find(node) != done.end()) {
if (node == nullptr || node->seen_ == seen) {
todo.pop_back();
continue;
}
......@@ -194,7 +195,7 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c
if (incl == FOLLOW) {
auto succs = succ(node);
for (const auto i : succs) {
if ((done.find(i) == done.end())
if ((i != nullptr && i->seen_ != seen)
// Handle the case where two subgraphs call each other.
// If the ValueNodeGraph's return is already in the todo list, do not follow it.
&& !((std::find(todo.begin(), todo.end(), i) != todo.end()) && (i->func_graph() != nullptr) &&
......@@ -206,7 +207,7 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c
} else if (incl == NOFOLLOW) {
// do nothing
} else if (incl == EXCLUDE) {
(void)done.insert(node);
node->seen_ = seen;
todo.pop_back();
continue;
} else {
......@@ -215,7 +216,7 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c
if (cont) {
continue;
}
(void)done.insert(node);
node->seen_ = seen;
res.push_back(node);
todo.pop_back();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册