diff --git a/mindspore/ccsrc/ir/func_graph.h b/mindspore/ccsrc/ir/func_graph.h index ca8e27c428653f743b77f43faee41a1c72e1c3b7..c66fee2d13e3619716ab1889e37f5af220a436d1 100644 --- a/mindspore/ccsrc/ir/func_graph.h +++ b/mindspore/ccsrc/ir/func_graph.h @@ -38,6 +38,32 @@ namespace mindspore { using BaseRefCounterMap = OrderedMap; using FuncGraphCounterMap = OrderedMap; +struct CNodeIndexHasher { + std::size_t operator()(const CNodeIndexPairPtr pair) const { + MS_EXCEPTION_IF_NULL(pair); + MS_EXCEPTION_IF_NULL(pair->first); + return hash_combine(pair->first->hash(), std::hash()(pair->second)); + } +}; + +struct CNodeIndexEqual { + bool operator()(const CNodeIndexPairPtr lhs, const CNodeIndexPairPtr rhs) const { + if (lhs == nullptr || rhs == nullptr) { + return false; + } + if (lhs == rhs) { + return true; + } + if (lhs->first != rhs->first) { + return false; + } + if (lhs->second != rhs->second) { + return false; + } + return true; + } +}; + template , class CounterEqual = std::equal_to> using CounterOrderedMap = OrderedMap; using AnfNodeCounterMap = CounterOrderedMap; diff --git a/mindspore/ccsrc/ir/manager.cc b/mindspore/ccsrc/ir/manager.cc index 2a76cecd6465c064f58c15550669ee2687d2b2bc..291a752405f3151ecd684e677ce8dfdd6d77e843 100644 --- a/mindspore/ccsrc/ir/manager.cc +++ b/mindspore/ccsrc/ir/manager.cc @@ -633,103 +633,7 @@ void FuncGraphTransaction::Commit() { manager_->CommitChanges(changes); } -FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager *const manager) - : manager_(manager), include_func_graph_none_(false) {} - -DepCollector::DepCollector(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) { - MS_EXCEPTION_IF_NULL(manager_); -} - -void DepCollector::OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kIncEdge); } - -void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); } - -template -bool CounterAnfNodeCollector::Inc(const FuncGraphPtr &func_graph, - const ValueT &key, int count) { - auto &d = count_nodes_map_[func_graph]; - if (d.count(key) == 0) { - d[key] = count; - return true; - } else { - d[key] += count; - } - return false; -} - -template -bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, - const ValueT &key, int count) { - MS_EXCEPTION_IF_NULL(func_graph); - auto &d = count_nodes_map_[func_graph]; - if (d.count(key) != 0) { - if (d[key] == count) { - (void)d.erase(key); - return true; - } else { - d[key] -= count; - if (d[key] < 0) { - MS_LOG(EXCEPTION) << "Count of key '" << key - << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); - } - } - } - return false; -} - -template -bool CounterAnfNodeCollector::Mod(const FuncGraphPtr &func_graph, - const ValueT &key, int count) { - if (count > 0) { - return Inc(func_graph, key, count); - } else if (count < 0) { - return Dec(func_graph, key, -count); - } else { - MS_LOG(EXCEPTION) << "Count of key '" << key - << "' cannot be 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); - } -} - -bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { - auto &d = count_func_graphs_map_[func_graph]; - if (d.count(key) == 0) { - d[key] = count; - return true; - } else { - d[key] += count; - } - return false; -} - -bool CounterFuncGraphCollector::Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { - auto &d = count_func_graphs_map_[func_graph]; - if (d.count(key) != 0) { - if (d[key] == count) { - (void)d.erase(key); - return true; - } else { - d[key] -= count; - if (d[key] < 0) { - MS_LOG(EXCEPTION) << "Count of key '" << key->ToString() - << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); - } - } - } - return false; -} - -bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count) { - if (count > 0) { - return Inc(func_graph, key, count); - } else if (count < 0) { - return Dec(func_graph, key, -count); - } else { - MS_LOG(EXCEPTION) << "Count of key '" << key->ToString() - << "' cannot be 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); - } -} - -DepComputer::DepComputer(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) { +DepComputer::DepComputer(const FuncGraphManager *const manager) : manager_(manager) { MS_EXCEPTION_IF_NULL(manager_); manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer); validate_ = false; @@ -839,16 +743,15 @@ void FVTotalComputer::RealRecompute() { for (auto &fg : manager->func_graphs()) { fv_total_analysis_[fg] = OrderedMap(); - count_nodes_map_[fg] = OrderedMap(); - count_func_graphs_map_[fg] = OrderedMap(); } for (auto &fg : manager->func_graphs()) { + // add all free variable nodes AnfNodeCounterMap items = fg->free_variables(); for (auto &iter : items) { auto curr = fg; while (curr != nullptr) { - (void)CounterAnfNodeCollector::Mod(curr, iter.first, iter.second); + fv_total_analysis_[curr][iter.first] = iter.second; curr = manager->parent(curr); if (curr != nullptr) { const AnfNodeSet &all_nodes = curr->nodes(); @@ -859,6 +762,7 @@ void FVTotalComputer::RealRecompute() { } } + // add all FGs of free variables auto &used = fg->func_graphs_used(); for (auto &iter : used) { auto p = manager->parent(iter.first); @@ -867,21 +771,11 @@ void FVTotalComputer::RealRecompute() { } auto curr = fg; while (curr != p) { - (void)CounterFuncGraphCollector::Mod(curr, iter.first, iter.second); + fv_total_analysis_[curr][iter.first] = iter.second; curr = manager->parent(curr); } } } - for (auto &fg : manager->func_graphs()) { - auto &fvp = count_nodes_map_[fg]; - auto &fvg = count_func_graphs_map_[fg]; - for (auto &item : fvp) { - fv_total_analysis_[fg][item.first] = item.second; - } - for (auto &item : fvg) { - fv_total_analysis_[fg][item.first] = item.second; - } - } } void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) { diff --git a/mindspore/ccsrc/ir/manager.h b/mindspore/ccsrc/ir/manager.h index e4e5a1fba8b53f5509ef1bb2db2c79ed456ee679..5da3812d25920b26fd6ba209eb819f914eff26ac 100644 --- a/mindspore/ccsrc/ir/manager.h +++ b/mindspore/ccsrc/ir/manager.h @@ -88,14 +88,6 @@ FuncGraphManagerPtr Manage(const std::vector &func_graphs, bool ma FuncGraphManagerPtr MakeManager(const std::vector &func_graphs = {}, bool manage = true); struct Signals { - Signal AddFuncGraph; - Signal DropFuncGraph; - Signal AddNode; - Signal DropNode; - Signal AddEdge; - Signal DropEdge; - Signal MoveAllCNode; - Signal InvalidateCollector; Signal InvalidateComputer; }; @@ -103,136 +95,15 @@ enum EdgeProcessDirection { kDecEdge = -1, kIncEdge = 1 }; using CNodeIndexPair = std::pair; using CNodeIndexPairPtr = std::shared_ptr; - -using FuncGraphToFuncGraphCounterMap = OrderedMap>; -template , class CollectorEqual = std::equal_to> -using FuncGraphToAnfNodeCounterMap = OrderedMap>; - -// analysis base class -class FuncGraphAnalysis { - public: - explicit FuncGraphAnalysis(const FuncGraphManager *const manager); - - virtual ~FuncGraphAnalysis() { manager_ = nullptr; } - - virtual size_t size() const { return 0; } - - virtual void OnAddFuncGraph(FuncGraphPtr) {} - - virtual void OnDropFuncGraph(FuncGraphPtr) {} - - virtual void OnMoveAllCNode(FuncGraphPtr, FuncGraphPtr) {} - - protected: - // subclass can reset their own member; - virtual void ExtraReset() {} - - virtual void OnAddNode(AnfNodePtr n) {} - - virtual void OnDropNode(AnfNodePtr n) {} - - virtual void OnAddEdge(AnfNodePtr, int, AnfNodePtr) {} - - virtual void OnDropEdge(AnfNodePtr, int, AnfNodePtr) {} - - const FuncGraphManager *manager_; - bool include_func_graph_none_; -}; - -using FuncGraphToAnfNodeMap = OrderedMap; - -struct CNodeIndexHasher { - std::size_t operator()(const CNodeIndexPairPtr pair) const { - MS_EXCEPTION_IF_NULL(pair); - MS_EXCEPTION_IF_NULL(pair->first); - return hash_combine(pair->first->hash(), std::hash()(pair->second)); - } -}; - -struct CNodeIndexEqual { - bool operator()(const CNodeIndexPairPtr lhs, const CNodeIndexPairPtr rhs) const { - if (lhs == nullptr || rhs == nullptr) { - return false; - } - if (lhs == rhs) { - return true; - } - if (lhs->first != rhs->first) { - return false; - } - if (lhs->second != rhs->second) { - return false; - } - return true; - } -}; - -// graphs analysis which compute in write, read needn't recompute -class DepCollector : public FuncGraphAnalysis { - public: - explicit DepCollector(const FuncGraphManager *manager); - ~DepCollector() override = default; - - void Reset() { ExtraReset(); } - void OnInvalidateCollector() { Reset(); } - - protected: - // inherit from FuncGraphAnalysis - void OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) override; - void OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) override; - // subclass can override; - virtual void OnModEdge(AnfNodePtr, int, AnfNodePtr, EdgeProcessDirection) {} -}; - -class CounterFuncGraphCollector : public DepCollector { - public: - explicit CounterFuncGraphCollector(const FuncGraphManager *m) : DepCollector(m) {} - ~CounterFuncGraphCollector() override = default; - FuncGraphToFuncGraphCounterMap &count_func_graphs_map() { return count_func_graphs_map_; } - // inherit from FuncGraphAnalysis - size_t size() const override { return count_func_graphs_map_.size(); } - void OnAddFuncGraph(FuncGraphPtr fg) final { count_func_graphs_map_[fg] = OrderedMap(); } - void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_func_graphs_map_.erase(fg); } - bool Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); - bool Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); - bool Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); - - FuncGraphToFuncGraphCounterMap count_func_graphs_map_; - - protected: - void ExtraReset() override { count_func_graphs_map_.clear(); } -}; - -template , class CollectorEqual = std::equal_to> -class CounterAnfNodeCollector : public DepCollector { - public: - explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {} - ~CounterAnfNodeCollector() override = default; - FuncGraphToAnfNodeCounterMap &count_nodes_map() { return count_nodes_map_; } - - size_t size() const override { return count_nodes_map_.size(); } - void OnAddFuncGraph(FuncGraphPtr fg) final { - count_nodes_map_[fg] = OrderedMap(); - } - void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); } - - bool Inc(const FuncGraphPtr &func_graph, const ValueT &key, int count); - bool Dec(const FuncGraphPtr &func_graph, const ValueT &key, int count); - bool Mod(const FuncGraphPtr &func_graph, const ValueT &key, int count); - - FuncGraphToAnfNodeCounterMap count_nodes_map_; - - protected: - void ExtraReset() override { count_nodes_map_.clear(); } -}; - using FuncGraphToFuncGraphSetMap = OrderedMap; -// graphs analysis which need dynamic compute by DepCollector in each read -class DepComputer : public FuncGraphAnalysis { +// analysis base class, graphs analysis which need dynamic compute by DepCollector in each read +class DepComputer { public: explicit DepComputer(const FuncGraphManager *manager); - ~DepComputer() override = default; + virtual ~DepComputer() { manager_ = nullptr; } + + virtual size_t size() const { return 0; } void Reset() { ExtraReset(); @@ -250,15 +121,14 @@ class DepComputer : public FuncGraphAnalysis { bool IsValidate(const FuncGraphPtr &fg) { return func_graphs_validate_[fg]; } - void OnAddFuncGraph(FuncGraphPtr) final { Reset(); } - - void OnDropFuncGraph(FuncGraphPtr) final { Reset(); } - protected: + // subclass can reset their own member; + virtual void ExtraReset() {} // subclass do the real compute virtual void RealRecompute() {} virtual void RealRecompute(FuncGraphPtr) {} + const FuncGraphManager *manager_; bool validate_; OrderedMap func_graphs_validate_; @@ -345,12 +215,9 @@ class ScopeComputer final : public DepComputer { using FVTotalMap = OrderedMap>; -class FVTotalComputer final : public DepComputer, - public CounterAnfNodeCollector, - public CounterFuncGraphCollector { +class FVTotalComputer final : public DepComputer { public: - explicit FVTotalComputer(const FuncGraphManager *m) - : DepComputer(m), CounterAnfNodeCollector(m), CounterFuncGraphCollector(m) {} + explicit FVTotalComputer(const FuncGraphManager *m) : DepComputer(m) {} ~FVTotalComputer() override = default; FVTotalMap &fv_total_analysis() { return fv_total_analysis_; } diff --git a/tests/ut/cpp/ir/manager_test.cc b/tests/ut/cpp/ir/manager_test.cc index 7b1e4d8554f3411d226f0f28dfa703563b94c22b..04b584ec102f8abcf17cff20122cd6f860ab4dc5 100644 --- a/tests/ut/cpp/ir/manager_test.cc +++ b/tests/ut/cpp/ir/manager_test.cc @@ -104,7 +104,7 @@ class NestingSpecs { return name; } - void Check(std::shared_ptr results) { + void Check(std::shared_ptr results) { if (expected_.empty() && expected_recursive_.empty()) { return; } @@ -120,18 +120,6 @@ class NestingSpecs { CheckRecursive(recursive); return; } - - auto counter_g = dynamic_pointer_cast(results); - if (counter_g != nullptr) { - CheckGraphCounter(counter_g); - return; - } - - auto counter_p = dynamic_pointer_cast>(results); - if (counter_p != nullptr) { - CheckAnfNodeCounter(counter_p); - return; - } } private: @@ -193,59 +181,6 @@ class NestingSpecs { ASSERT_EQ(clean_results, expected_); } - // Add CheckNesting function - void CheckAnfNodeCounter(std::shared_ptr> results) { - std::map> clean_results; - for (auto& iter : results->count_nodes_map()) { - auto key = iter.first; - auto value = iter.second; - if (key == nullptr) { - continue; - } - std::string k = Name(key); - - std::set v; - for (auto& node : value) { - auto fg = node.first; - if (!Name(fg).empty()) { - v.insert(Name(fg)); - } - } - - if (!v.empty()) { - clean_results[k] = v; - } - } - - ASSERT_EQ(clean_results, expected_); - } - - void CheckGraphCounter(std::shared_ptr results) { - std::map> clean_results; - for (auto& iter : results->count_func_graphs_map()) { - auto key = iter.first; - auto value = iter.second; - if (key == nullptr) { - continue; - } - std::string k = Name(key); - - std::set v; - for (auto& node : value) { - auto fg = node.first; - if (!Name(fg).empty()) { - v.insert(Name(fg)); - } - } - - if (!v.empty()) { - clean_results[k] = v; - } - } - - ASSERT_EQ(clean_results, expected_); - } - void CheckRecursive(std::shared_ptr results) { std::map clean_results; for (auto iter = results->recursive_analysis().begin(); iter != results->recursive_analysis().end(); ++iter) {