提交 848d1920 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1359 Optimize the IR modules.

Merge pull request !1359 from ZhangQinghua/master
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include "utils/visible.h" #include "utils/visible.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "utils/ordered_set.h" #include "utils/ordered_set.h"
#include "utils/ordered_map.h"
namespace mindspore { namespace mindspore {
template <typename T> template <typename T>
......
...@@ -47,6 +47,7 @@ FuncGraph::FuncGraph() ...@@ -47,6 +47,7 @@ FuncGraph::FuncGraph()
: flags_(), : flags_(),
transforms_(), transforms_(),
parameter_default_value_(), parameter_default_value_(),
seen_(0),
parameters_(), parameters_(),
has_vararg_(false), has_vararg_(false),
has_kwarg_(false), has_kwarg_(false),
...@@ -195,25 +196,93 @@ GraphDebugInfoPtr FuncGraph::debug_info() { ...@@ -195,25 +196,93 @@ GraphDebugInfoPtr FuncGraph::debug_info() {
return this->debug_info_; return this->debug_info_;
} }
const AnfNodeSet &FuncGraph::nodes() { const AnfNodeSet &FuncGraph::nodes() { return nodes_; }
auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng); void FuncGraph::CopyNodes(const FuncGraphPtr &source) { nodes_ = source->nodes(); }
auto &nodes = mng->nodes();
return nodes[shared_from_base<FuncGraph>()]; void FuncGraph::ClearNodes() { nodes_.clear(); }
void FuncGraph::AddNode(AnfNodePtr node) { nodes_.add(node); }
void FuncGraph::DropNode(AnfNodePtr node) {
nodes_.erase(node);
auto graph = node->func_graph();
// Remove the node from order list.
if (graph) {
graph->EraseUnusedNodeInOrder(node);
}
} }
const AnfNodeCounterMap &FuncGraph::value_nodes() { const AnfNodeCounterMap &FuncGraph::value_nodes() { return value_nodes_; }
auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng); void FuncGraph::CopyValueNodes(const FuncGraphPtr &source) {
auto &cts = mng->valuenodes(); auto &others = source->value_nodes();
return cts[shared_from_base<FuncGraph>()]; for (auto it = others.begin(); it != others.end(); it++) {
AddValueNode(it->first, it->second);
}
} }
const AnfNodeCounterMap &FuncGraph::free_variables_direct() { void FuncGraph::ClearValueNodes() { value_nodes_.clear(); }
auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng); void FuncGraph::AddValueNode(AnfNodePtr node, int count) {
auto &fv_direct = mng->free_variables_direct(); if (value_nodes_.count(node) == 0) {
return fv_direct[shared_from_base<FuncGraph>()]; value_nodes_[node] = count;
} else {
value_nodes_[node] += count;
}
}
void FuncGraph::DropValueNode(AnfNodePtr node) {
if (value_nodes_.count(node) != 0) {
if (value_nodes_[node] == 1) {
(void)value_nodes_.erase(node);
} else {
value_nodes_[node]--;
if (value_nodes_[node] < 0) {
MS_LOG(EXCEPTION) << "Count of ValueNode '" << node
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
}
}
}
}
const AnfNodeCounterMap &FuncGraph::free_variables() { return free_variables_; }
void FuncGraph::CopyFreeVariables(const FuncGraphPtr &source) {
auto &others = source->free_variables();
for (auto it = others.begin(); it != others.end(); it++) {
if (it->first->func_graph().get() != this) {
(void)AddFreeVariable(it->first, it->second);
}
}
}
void FuncGraph::ClearFreeVariables() { free_variables_.clear(); }
bool FuncGraph::AddFreeVariable(AnfNodePtr node, int count) {
if (free_variables_.count(node) == 0) {
free_variables_[node] = count;
return true;
} else {
free_variables_[node] += count;
return false;
}
}
bool FuncGraph::DropFreeVariable(AnfNodePtr node) {
if (free_variables_.count(node) != 0) {
if (free_variables_[node] == 1) {
(void)free_variables_.erase(node);
return true;
} else {
free_variables_[node]--;
if (free_variables_[node] < 0) {
MS_LOG(EXCEPTION) << "Count of free variable '" << node
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
}
}
}
return false;
} }
const BaseRefCounterMap &FuncGraph::free_variables_total() { const BaseRefCounterMap &FuncGraph::free_variables_total() {
...@@ -249,11 +318,42 @@ std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() { ...@@ -249,11 +318,42 @@ std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() {
return func_graphs; return func_graphs;
} }
const FuncGraphCounterMap &FuncGraph::func_graphs_used() { const FuncGraphCounterMap &FuncGraph::func_graphs_used() { return func_graphs_used_; }
auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng); void FuncGraph::CopyFuncGraphsUsed(const FuncGraphPtr &source) {
auto &used = mng->func_graphs_used(); auto &others = source->func_graphs_used();
return used[shared_from_base<FuncGraph>()]; for (auto it = others.begin(); it != others.end(); it++) {
(void)AddFuncGraphUsed(it->first, it->second);
}
func_graphs_used_.erase(source);
}
void FuncGraph::ClearFuncGraphsUsed() { func_graphs_used_.clear(); }
bool FuncGraph::AddFuncGraphUsed(FuncGraphPtr fg, int count) {
if (func_graphs_used_.count(fg) == 0) {
func_graphs_used_[fg] = count;
return true;
} else {
func_graphs_used_[fg] += count;
return false;
}
}
bool FuncGraph::DropFuncGraphUsed(FuncGraphPtr fg) {
if (func_graphs_used_.count(fg) != 0) {
if (func_graphs_used_[fg] == 1) {
(void)func_graphs_used_.erase(fg);
return true;
} else {
func_graphs_used_[fg]--;
if (func_graphs_used_[fg] < 0) {
MS_LOG(EXCEPTION) << "Count of FuncGraph '" << fg
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
}
}
}
return false;
} }
const FuncGraphSet &FuncGraph::func_graphs_used_total() { const FuncGraphSet &FuncGraph::func_graphs_used_total() {
...@@ -263,15 +363,75 @@ const FuncGraphSet &FuncGraph::func_graphs_used_total() { ...@@ -263,15 +363,75 @@ const FuncGraphSet &FuncGraph::func_graphs_used_total() {
return used; return used;
} }
const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() { const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() { return func_graph_cnodes_index_; }
auto mng = manager_.lock();
if (mng == nullptr) { void FuncGraph::CopyFuncGraphCNodesIndex(const FuncGraphPtr &source) {
MS_LOG(EXCEPTION) << "BUG: no manager for this func graph: " << ToString() auto &others = source->func_graph_cnodes_index();
<< " NodeInfo: " << trace::GetDebugInfo(debug_info()); for (auto it = others.begin(); it != others.end(); it++) {
// Ignore the user graph who may own itself.
auto fg = it->first->first->func_graph();
MS_EXCEPTION_IF_NULL(fg);
if (fg.get() != this) {
AddFuncGraphCNodeIndex(it->first, it->second);
}
}
}
void FuncGraph::ClearFuncGraphCNodesIndex() { func_graph_cnodes_index_.clear(); }
void FuncGraph::AddFuncGraphCNodeIndex(CNodeIndexPairPtr pair, int count) {
if (func_graph_cnodes_index_.count(pair) == 0) {
func_graph_cnodes_index_[pair] = count;
} else {
func_graph_cnodes_index_[pair] += count;
}
}
void FuncGraph::DropFuncGraphCNodeIndex(CNodeIndexPairPtr pair) {
if (func_graph_cnodes_index_.count(pair) != 0) {
if (func_graph_cnodes_index_[pair] == 1) {
(void)func_graph_cnodes_index_.erase(pair);
} else {
func_graph_cnodes_index_[pair]--;
if (func_graph_cnodes_index_[pair] < 0) {
MS_LOG(EXCEPTION) << "Count of CNode/Index '" << pair->first << "/" << pair->second
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
}
}
}
}
const FuncGraphCounterMap &FuncGraph::j_func_graphs() { return j_func_graphs_; }
void FuncGraph::CopyJFuncGraphs(const FuncGraphPtr &source) {
auto &others = source->j_func_graphs();
for (auto it = others.begin(); it != others.end(); it++) {
AddJFuncGraph(it->first, it->second);
}
}
void FuncGraph::ClearJFuncGraphs() { j_func_graphs_.clear(); }
void FuncGraph::AddJFuncGraph(FuncGraphPtr fg, int count) {
if (j_func_graphs_.count(fg) == 0) {
j_func_graphs_[fg] = count;
} else {
j_func_graphs_[fg] += count;
}
}
void FuncGraph::DropJFuncGraph(FuncGraphPtr fg) {
if (j_func_graphs_.count(fg) != 0) {
if (j_func_graphs_[fg] == 1) {
(void)j_func_graphs_.erase(fg);
} else {
j_func_graphs_[fg]--;
if (j_func_graphs_[fg] < 0) {
MS_LOG(EXCEPTION) << "Count of J FuncGraph '" << fg
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
}
}
} }
MS_EXCEPTION_IF_NULL(mng);
auto &cnode = mng->func_graph_cnodes_index();
return cnode[shared_from_base<FuncGraph>()];
} }
FuncGraphPtr FuncGraph::parent() { FuncGraphPtr FuncGraph::parent() {
...@@ -662,10 +822,10 @@ void FuncGraph::EraseUnusedNodeInOrder() { ...@@ -662,10 +822,10 @@ void FuncGraph::EraseUnusedNodeInOrder() {
if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
auto mng = manager_.lock(); auto mng = manager_.lock();
if (mng) { if (mng) {
auto nodes = mng->nodes()[shared_from_base<FuncGraph>()]; auto &all_nodes = nodes();
// Erase unused cnode. // Erase unused cnode.
for (auto it = order_.begin(); it != order_.end();) { for (auto it = order_.begin(); it != order_.end();) {
if (nodes.count(*it)) { if (all_nodes.count(*it)) {
(void)it++; (void)it++;
} else { } else {
MS_LOG(DEBUG) << "Remove node " << (*it)->ToString() << " in graph " << ToString() << " order."; MS_LOG(DEBUG) << "Remove node " << (*it)->ToString() << " in graph " << ToString() << " order.";
...@@ -702,11 +862,11 @@ void FuncGraph::CheckOrder() { ...@@ -702,11 +862,11 @@ void FuncGraph::CheckOrder() {
} }
auto mng = manager_.lock(); auto mng = manager_.lock();
if (mng != nullptr) { if (mng != nullptr) {
const auto &nodes = mng->nodes()[shared_from_base<FuncGraph>()]; const auto &all_nodes = nodes();
if (nodes.size() != (order_.size() + parameters_.size())) { if (all_nodes.size() != (order_.size() + parameters_.size())) {
DumpCNodeList(); DumpCNodeList();
MS_LOG(EXCEPTION) << "CNode order size " << order_.size() << " is not equal to managed node size " MS_LOG(EXCEPTION) << "CNode order size " << order_.size() << " is not equal to managed node size "
<< nodes.size() - parameters_.size() << "."; << all_nodes.size() - parameters_.size() << ".";
} }
} }
MS_LOG(DEBUG) << "Check order okay."; MS_LOG(DEBUG) << "Check order okay.";
...@@ -840,6 +1000,11 @@ void FuncGraph::SetEffectDepends(const std::vector<AnfNodePtr> &depend_inputs) { ...@@ -840,6 +1000,11 @@ void FuncGraph::SetEffectDepends(const std::vector<AnfNodePtr> &depend_inputs) {
} }
} }
size_t NewFgSeenGeneration() {
static size_t fg_seen_generation = 0;
return ++fg_seen_generation;
}
const PrimitivePtr FuncGraphTransform::func_graph_prim_ = std::make_shared<Primitive>("FuncGraph"); const PrimitivePtr FuncGraphTransform::func_graph_prim_ = std::make_shared<Primitive>("FuncGraph");
const char kFuncGraphFlagUndetermined[] = "Undeterminate"; const char kFuncGraphFlagUndetermined[] = "Undeterminate";
} // namespace mindspore } // namespace mindspore
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <memory> #include <memory>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <functional>
#include "ir/anf.h" #include "ir/anf.h"
#include "ir/manager.h" #include "ir/manager.h"
...@@ -36,8 +37,13 @@ ...@@ -36,8 +37,13 @@
namespace mindspore { namespace mindspore {
using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>; using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>;
using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>; using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>;
using AnfNodeCounterMap = OrderedMap<AnfNodePtr, int>;
using CNodeIndexCounterMap = OrderedMap<CNodeIndexPairPtr, int, CNodeIndexHasher, CNodeIndexEqual>; template <typename ValueT, class CounterHash = std::hash<ValueT>, class CounterEqual = std::equal_to<ValueT>>
using CounterOrderedMap = OrderedMap<ValueT, int, CounterHash, CounterEqual>;
using AnfNodeCounterMap = CounterOrderedMap<AnfNodePtr>;
using CNodeIndexCounterMap = CounterOrderedMap<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual>;
using FuncGraphMap = OrderedMap<FuncGraphPtr, int>;
const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values"; const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values";
const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline"; const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
...@@ -183,12 +189,24 @@ class FuncGraph : public FuncGraphBase { ...@@ -183,12 +189,24 @@ class FuncGraph : public FuncGraphBase {
// get all nodes belonging to this func graph // get all nodes belonging to this func graph
const AnfNodeSet &nodes(); const AnfNodeSet &nodes();
void CopyNodes(const FuncGraphPtr &source);
void ClearNodes();
void AddNode(AnfNodePtr node);
void DropNode(AnfNodePtr node);
// get all value_nodes belonging to this func graph // get all value_nodes belonging to this func graph
const AnfNodeCounterMap &value_nodes(); const AnfNodeCounterMap &value_nodes();
void CopyValueNodes(const FuncGraphPtr &source);
// get all vars directly pointed to in this func graph void ClearValueNodes();
const AnfNodeCounterMap &free_variables_direct(); void AddValueNode(AnfNodePtr node, int count = 1);
void DropValueNode(AnfNodePtr node);
// get all free vars directly used in this func graph
const AnfNodeCounterMap &free_variables();
void CopyFreeVariables(const FuncGraphPtr &source);
void ClearFreeVariables();
bool AddFreeVariable(AnfNodePtr node, int count = 1);
bool DropFreeVariable(AnfNodePtr node);
// get all vars required by this func graph // get all vars required by this func graph
const BaseRefCounterMap &free_variables_total(); const BaseRefCounterMap &free_variables_total();
...@@ -199,14 +217,29 @@ class FuncGraph : public FuncGraphBase { ...@@ -199,14 +217,29 @@ class FuncGraph : public FuncGraphBase {
// get all vars that are func graphs // get all vars that are func graphs
std::vector<FuncGraphPtr> free_variables_func_graphs(); std::vector<FuncGraphPtr> free_variables_func_graphs();
// get all func graphs directly used by this func graph // get all value nodes of func graph directly used by this func graph
const FuncGraphCounterMap &func_graphs_used(); const FuncGraphCounterMap &func_graphs_used();
void CopyFuncGraphsUsed(const FuncGraphPtr &source);
void ClearFuncGraphsUsed();
bool AddFuncGraphUsed(FuncGraphPtr fg, int count = 1);
bool DropFuncGraphUsed(FuncGraphPtr fg);
// get all value nodes of J func graph directly used by this func graph
const FuncGraphCounterMap &j_func_graphs();
void CopyJFuncGraphs(const FuncGraphPtr &source);
void ClearJFuncGraphs();
void AddJFuncGraph(FuncGraphPtr fg, int count = 1);
void DropJFuncGraph(FuncGraphPtr fg);
// get all func graphs nested used by this func graph // get all func graphs nested used by this func graph
const FuncGraphSet &func_graphs_used_total(); const FuncGraphSet &func_graphs_used_total();
// get all user value nodes of this func graph // get all user value nodes of this func graph, by CNode and its input's index
const CNodeIndexCounterMap &func_graph_cnodes_index(); const CNodeIndexCounterMap &func_graph_cnodes_index();
void CopyFuncGraphCNodesIndex(const FuncGraphPtr &source);
void ClearFuncGraphCNodesIndex();
void AddFuncGraphCNodeIndex(CNodeIndexPairPtr node, int count = 1);
void DropFuncGraphCNodeIndex(CNodeIndexPairPtr node);
// Return the parent of this graph. // Return the parent of this graph.
FuncGraphPtr parent(); FuncGraphPtr parent();
...@@ -256,6 +289,7 @@ class FuncGraph : public FuncGraphBase { ...@@ -256,6 +289,7 @@ class FuncGraph : public FuncGraphBase {
// parameter default value // parameter default value
std::map<std::string, AnfNodePtr> parameter_default_value_; std::map<std::string, AnfNodePtr> parameter_default_value_;
std::unordered_map<AnfNodePtr, AnfNodePtr> make_ref_params_; std::unordered_map<AnfNodePtr, AnfNodePtr> make_ref_params_;
size_t seen_;
std::list<CNodePtr> GetOrderedCnodes(); std::list<CNodePtr> GetOrderedCnodes();
void EraseUnusedNodeInOrder(const AnfNodePtr &n); void EraseUnusedNodeInOrder(const AnfNodePtr &n);
...@@ -270,6 +304,24 @@ class FuncGraph : public FuncGraphBase { ...@@ -270,6 +304,24 @@ class FuncGraph : public FuncGraphBase {
// graph is manipulated by manager and others // graph is manipulated by manager and others
friend FuncGraphManager; friend FuncGraphManager;
// all nodes of the function
AnfNodeSet nodes_;
// all value nodes of the function
AnfNodeCounterMap value_nodes_;
// all func graph value nodes of the function
FuncGraphCounterMap func_graphs_used_;
// all free variables of the function
AnfNodeCounterMap free_variables_;
// all value nodes calling J in the function
FuncGraphCounterMap j_func_graphs_;
// all user value nodes of this func graph, recording by CNode and its input's index
CNodeIndexCounterMap func_graph_cnodes_index_;
// parameters of this function // parameters of this function
std::vector<AnfNodePtr> parameters_; std::vector<AnfNodePtr> parameters_;
std::vector<AnfNodePtr> paramter_obj_nodes_; std::vector<AnfNodePtr> paramter_obj_nodes_;
...@@ -313,6 +365,8 @@ inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphP ...@@ -313,6 +365,8 @@ inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphP
return fg->NewCNode(inputs); return fg->NewCNode(inputs);
} }
size_t NewFgSeenGeneration();
// Find the root cnodes of a segment of cnodes. // Find the root cnodes of a segment of cnodes.
std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment); std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment);
// Find the leaf cnodes of a segment of cnodes. // Find the leaf cnodes of a segment of cnodes.
......
...@@ -123,7 +123,7 @@ void Cloner::CloneValueNodes(const FuncGraphPtr &func_graph) { ...@@ -123,7 +123,7 @@ void Cloner::CloneValueNodes(const FuncGraphPtr &func_graph) {
if (!clone_all_valuenodes_) { if (!clone_all_valuenodes_) {
return; return;
} }
auto &value_nodes = manager_->valuenodes()[func_graph]; auto &value_nodes = func_graph->value_nodes();
for (auto &value_node : value_nodes) { for (auto &value_node : value_nodes) {
auto old_node = value_node.first; auto old_node = value_node.first;
MS_EXCEPTION_IF_NULL(old_node); MS_EXCEPTION_IF_NULL(old_node);
...@@ -153,9 +153,9 @@ void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) { ...@@ -153,9 +153,9 @@ void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) {
if (!clone_all_used_graphs_) { if (!clone_all_used_graphs_) {
return; return;
} }
auto &used_graphs = manager_->func_graphs_used()[func_graph]; auto &used = func_graph->func_graphs_used();
for (auto &used_graph : used_graphs) { for (auto &fg : used) {
todo_.push_back({used_graph.first, nullptr, {}}); todo_.push_back({fg.first, nullptr, {}});
} }
} }
...@@ -185,7 +185,7 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const Func ...@@ -185,7 +185,7 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const Func
} }
target_func_graph->set_return(return_node); target_func_graph->set_return(return_node);
auto &cnodes = manager_->func_graph_cnodes_index()[func_graph]; auto &cnodes = func_graph->func_graph_cnodes_index();
for (auto &cnode : cnodes) { for (auto &cnode : cnodes) {
auto parent = cnode.first->first->cast<CNodePtr>(); auto parent = cnode.first->first->cast<CNodePtr>();
auto valuenode = parent->input(cnode.first->second); auto valuenode = parent->input(cnode.first->second);
...@@ -441,7 +441,7 @@ void Cloner::CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &t ...@@ -441,7 +441,7 @@ void Cloner::CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &t
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(target_func_graph); MS_EXCEPTION_IF_NULL(target_func_graph);
MS_EXCEPTION_IF_NULL(manager_); MS_EXCEPTION_IF_NULL(manager_);
const AnfNodeSet &nodes = manager_->nodes()[func_graph]; const AnfNodeSet &nodes = func_graph->nodes();
for (auto &node : nodes) { for (auto &node : nodes) {
CloneNode(node, target_func_graph); CloneNode(node, target_func_graph);
} }
......
此差异已折叠。
...@@ -140,44 +140,6 @@ class FuncGraphAnalysis { ...@@ -140,44 +140,6 @@ class FuncGraphAnalysis {
using FuncGraphToAnfNodeMap = OrderedMap<FuncGraphPtr, AnfNodeSet>; using FuncGraphToAnfNodeMap = OrderedMap<FuncGraphPtr, AnfNodeSet>;
// graphs analysis which compute in write, read needn't recompute
class DepCollector : public FuncGraphAnalysis {
public:
explicit DepCollector(const FuncGraphManager *manager);
~DepCollector() override = default;
void Reset() { ExtraReset(); }
void OnInvalidateCollector() { Reset(); }
protected:
// inherit from FuncGraphAnalysis
void OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) override;
void OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) override;
// subclass can override;
virtual void OnModEdge(AnfNodePtr, int, AnfNodePtr, EdgeProcessDirection) {}
};
class NodesCollector final : public DepCollector {
public:
explicit NodesCollector(const FuncGraphManager *m);
~NodesCollector() override = default;
const FuncGraphToAnfNodeMap &nodes_analysis() const { return nodes_analysis_; }
size_t size() const override { return nodes_analysis_.size(); }
void OnAddFuncGraph(FuncGraphPtr fg) override { nodes_analysis_[fg] = AnfNodeSet(); }
void OnDropFuncGraph(FuncGraphPtr fg) override { (void)nodes_analysis_.erase(fg); }
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
FuncGraphToAnfNodeMap nodes_analysis_;
protected:
void ExtraReset() override { nodes_analysis_.clear(); }
void OnAddNode(AnfNodePtr n) override;
void OnDropNode(AnfNodePtr n) override;
};
struct CNodeIndexHasher { struct CNodeIndexHasher {
std::size_t operator()(const CNodeIndexPairPtr pair) const { std::size_t operator()(const CNodeIndexPairPtr pair) const {
MS_EXCEPTION_IF_NULL(pair); MS_EXCEPTION_IF_NULL(pair);
...@@ -204,59 +166,21 @@ struct CNodeIndexEqual { ...@@ -204,59 +166,21 @@ struct CNodeIndexEqual {
} }
}; };
template <typename ValueT, class CollectorHash = std::hash<ValueT>, class CollectorEqual = std::equal_to<ValueT>> // graphs analysis which compute in write, read needn't recompute
class CounterAnfNodeCollector : public DepCollector { class DepCollector : public FuncGraphAnalysis {
public:
explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {}
~CounterAnfNodeCollector() override = default;
FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> &count_nodes_map() { return count_nodes_map_; }
size_t size() const override { return count_nodes_map_.size(); }
void OnAddFuncGraph(FuncGraphPtr fg) final {
count_nodes_map_[fg] = OrderedMap<ValueT, int, CollectorHash, CollectorEqual>();
}
void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); }
bool Inc(const FuncGraphPtr &func_graph, const ValueT &key, int count);
bool Dec(const FuncGraphPtr &func_graph, const ValueT &key, int count);
bool Mod(const FuncGraphPtr &func_graph, const ValueT &key, int count);
FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> count_nodes_map_;
protected:
void ExtraReset() override { count_nodes_map_.clear(); }
};
class ValueNodesCollector final : public CounterAnfNodeCollector<AnfNodePtr> {
public:
explicit ValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
~ValueNodesCollector() override = default;
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
protected:
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
};
// Record the CNode and its input index, who points to the function graph.
class FuncGraphUsersCNodeIndexCollector final
: public CounterAnfNodeCollector<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual> {
public: public:
explicit FuncGraphUsersCNodeIndexCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} explicit DepCollector(const FuncGraphManager *manager);
~FuncGraphUsersCNodeIndexCollector() override = default; ~DepCollector() override = default;
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
protected:
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
};
class FVDirectCollector final : public CounterAnfNodeCollector<AnfNodePtr> { void Reset() { ExtraReset(); }
public: void OnInvalidateCollector() { Reset(); }
explicit FVDirectCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
~FVDirectCollector() override = default;
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
protected: protected:
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; // inherit from FuncGraphAnalysis
void OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) override;
void OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) override;
// subclass can override;
virtual void OnModEdge(AnfNodePtr, int, AnfNodePtr, EdgeProcessDirection) {}
}; };
class CounterFuncGraphCollector : public DepCollector { class CounterFuncGraphCollector : public DepCollector {
...@@ -278,50 +202,27 @@ class CounterFuncGraphCollector : public DepCollector { ...@@ -278,50 +202,27 @@ class CounterFuncGraphCollector : public DepCollector {
void ExtraReset() override { count_func_graphs_map_.clear(); } void ExtraReset() override { count_func_graphs_map_.clear(); }
}; };
class FuncGraphChildDirect final : public CounterFuncGraphCollector { template <typename ValueT, class CollectorHash = std::hash<ValueT>, class CollectorEqual = std::equal_to<ValueT>>
public: class CounterAnfNodeCollector : public DepCollector {
explicit FuncGraphChildDirect(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
~FuncGraphChildDirect() override = default;
protected:
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
};
// graph's all parents, parentsdirect have a map, which key is graph, value is this graph's all direct and proxy
// parents:
// 1.proxy parent: graph g use graph f, key is g, value is ParentProxy(f) because f's parent will be g's parent
// 2.direct parent: if graph g's node a used free_variable node in graph f, g's direct parent is f key is g, value is f
class FuncGraphParentsDirectCollector final : public CounterFuncGraphCollector {
public: public:
explicit FuncGraphParentsDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {}
~FuncGraphParentsDirectCollector() override = default; ~CounterAnfNodeCollector() override = default;
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> &count_nodes_map() { return count_nodes_map_; }
protected:
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
};
// graph's all used graphs: key is g, value is g used graph size_t size() const override { return count_nodes_map_.size(); }
class FuncGraphsUsedCollector final : public CounterFuncGraphCollector { void OnAddFuncGraph(FuncGraphPtr fg) final {
public: count_nodes_map_[fg] = OrderedMap<ValueT, int, CollectorHash, CollectorEqual>();
explicit FuncGraphsUsedCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} }
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); }
~FuncGraphsUsedCollector() override = default;
protected: bool Inc(const FuncGraphPtr &func_graph, const ValueT &key, int count);
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; bool Dec(const FuncGraphPtr &func_graph, const ValueT &key, int count);
}; bool Mod(const FuncGraphPtr &func_graph, const ValueT &key, int count);
class FuncGraphJDirectCollector final : public CounterFuncGraphCollector { FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> count_nodes_map_;
public:
explicit FuncGraphJDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
void OnMoveAllCNode(FuncGraphPtr src, const FuncGraphPtr dst) override;
~FuncGraphJDirectCollector() override = default;
protected: protected:
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; void ExtraReset() override { count_nodes_map_.clear(); }
}; };
using FuncGraphToFuncGraphSetMap = OrderedMap<FuncGraphPtr, FuncGraphSet>; using FuncGraphToFuncGraphSetMap = OrderedMap<FuncGraphPtr, FuncGraphSet>;
...@@ -367,8 +268,8 @@ class DepComputer : public FuncGraphAnalysis { ...@@ -367,8 +268,8 @@ class DepComputer : public FuncGraphAnalysis {
// graph g's all direct or proxy parents // graph g's all direct or proxy parents
class FuncGraphParentsTotalComputer final : public DepComputer { class FuncGraphParentsTotalComputer final : public DepComputer {
public: public:
explicit FuncGraphParentsTotalComputer(const FuncGraphManager *m) : DepComputer(m), all_parents_direct_(nullptr) {} explicit FuncGraphParentsTotalComputer(const FuncGraphManager *m) : DepComputer(m) {}
~FuncGraphParentsTotalComputer() override { all_parents_direct_ = nullptr; } ~FuncGraphParentsTotalComputer() override = default;
FuncGraphToFuncGraphSetMap &func_graph_parents_total_analysis() { return func_graph_parents_total_analysis_; } FuncGraphToFuncGraphSetMap &func_graph_parents_total_analysis() { return func_graph_parents_total_analysis_; }
...@@ -382,10 +283,7 @@ class FuncGraphParentsTotalComputer final : public DepComputer { ...@@ -382,10 +283,7 @@ class FuncGraphParentsTotalComputer final : public DepComputer {
void RealRecompute(FuncGraphPtr fg) override; void RealRecompute(FuncGraphPtr fg) override;
private: private:
FuncGraphSetPtr SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path = std::make_shared<FuncGraphSet>()); FuncGraphSetPtr SeekParents(const FuncGraphPtr &fg, size_t seen_num);
// when SeekParents calls itself recursively, it can access these variables by class member
// other than pass by formal parameters, it can save 1 parameter for SeekParents().
FuncGraphToFuncGraphCounterMap *all_parents_direct_;
}; };
using FuncGraphToFuncGraphMap = OrderedMap<FuncGraphPtr, FuncGraphPtr>; using FuncGraphToFuncGraphMap = OrderedMap<FuncGraphPtr, FuncGraphPtr>;
...@@ -525,7 +423,7 @@ class FuncGraphJTotalComputer final : public DepComputer { ...@@ -525,7 +423,7 @@ class FuncGraphJTotalComputer final : public DepComputer {
void ExtraReset() override { j_total_analysis_.clear(); } void ExtraReset() override { j_total_analysis_.clear(); }
void RealRecompute(FuncGraphPtr fg) override; void RealRecompute(FuncGraphPtr fg) override;
bool SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPtr &path); bool SeekJ(const FuncGraphPtr &fg, size_t seen_num);
}; };
class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
...@@ -562,30 +460,6 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { ...@@ -562,30 +460,6 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
NodeUsersMap &node_users() { return node_users_; } NodeUsersMap &node_users() { return node_users_; }
FuncGraphToAnfNodeMap &nodes() const { return nodes_->nodes_analysis_; }
FuncGraphToAnfNodeCounterMap<AnfNodePtr> &valuenodes() const { return valuenodes_->count_nodes_map_; }
FuncGraphToAnfNodeCounterMap<AnfNodePtr> &free_variables_direct() const {
return free_variables_direct_->count_nodes_map_;
}
FuncGraphToAnfNodeCounterMap<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual> &func_graph_cnodes_index() const {
return func_graph_cnodes_index_->count_nodes_map_;
}
FuncGraphToFuncGraphCounterMap &func_graphs_used() const { return func_graphs_used_->count_func_graphs_map_; }
FuncGraphToFuncGraphCounterMap &func_graph_child_direct() const {
return func_graph_child_direct_->count_func_graphs_map_;
}
FuncGraphToFuncGraphCounterMap &func_graph_parents_direct() const {
return func_graph_parents_direct_->count_func_graphs_map_;
}
FuncGraphToFuncGraphCounterMap &func_graph_j_direct() const { return func_graph_j_direct_->count_func_graphs_map_; }
FVTotalMap &free_variables_total() const; FVTotalMap &free_variables_total() const;
FuncGraphSet &func_graph_parents_total(const FuncGraphPtr &fg) const; FuncGraphSet &func_graph_parents_total(const FuncGraphPtr &fg) const;
...@@ -610,14 +484,6 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { ...@@ -610,14 +484,6 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
// Static Analysis // Static Analysis
NodeUsersMap node_users_; NodeUsersMap node_users_;
AnfNodeSet all_nodes_; // managed nodes AnfNodeSet all_nodes_; // managed nodes
std::shared_ptr<NodesCollector> nodes_;
std::shared_ptr<ValueNodesCollector> valuenodes_;
std::shared_ptr<FVDirectCollector> free_variables_direct_;
std::shared_ptr<FuncGraphUsersCNodeIndexCollector> func_graph_cnodes_index_;
std::shared_ptr<FuncGraphsUsedCollector> func_graphs_used_;
std::shared_ptr<FuncGraphChildDirect> func_graph_child_direct_;
std::shared_ptr<FuncGraphParentsDirectCollector> func_graph_parents_direct_;
std::shared_ptr<FuncGraphJDirectCollector> func_graph_j_direct_;
// Dynamic Analysis // Dynamic Analysis
std::shared_ptr<ParentComputer> func_graph_parent_; std::shared_ptr<ParentComputer> func_graph_parent_;
...@@ -630,6 +496,9 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { ...@@ -630,6 +496,9 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
FuncGraphSetPtr MaybeDropNodes(const std::vector<AnfNodePtr> &nodes); FuncGraphSetPtr MaybeDropNodes(const std::vector<AnfNodePtr> &nodes);
void ParseChanges(const std::vector<Change> &changes, EdgeTupleCounter *add_edges, EdgeTupleCounter *rm_edges, void ParseChanges(const std::vector<Change> &changes, EdgeTupleCounter *add_edges, EdgeTupleCounter *rm_edges,
Counter<AnfNodePtr> *adds, Counter<AnfNodePtr> *rms); Counter<AnfNodePtr> *adds, Counter<AnfNodePtr> *rms);
void AddEdge(AnfNodePtr node, int index, AnfNodePtr input);
void DropEdge(AnfNodePtr node, int index, AnfNodePtr input);
void MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target);
FuncGraphSet roots_; // managed roots FuncGraphSet roots_; // managed roots
FuncGraphSet func_graphs_; // managed func graphs FuncGraphSet func_graphs_; // managed func graphs
......
...@@ -492,7 +492,7 @@ void DFunctor::MapParamObject() { ...@@ -492,7 +492,7 @@ void DFunctor::MapParamObject() {
void DFunctor::MapValueObject() { void DFunctor::MapValueObject() {
// Map ValueNode. // Map ValueNode.
auto manager = resources_->manager(); auto manager = resources_->manager();
auto &value_nodes = manager->valuenodes()[primal_graph_]; auto &value_nodes = primal_graph_->value_nodes();
for (const auto &value_pair : value_nodes) { for (const auto &value_pair : value_nodes) {
auto node = value_pair.first; auto node = value_pair.first;
auto parent_adjoint = FindAdjoint(node); auto parent_adjoint = FindAdjoint(node);
......
...@@ -119,7 +119,7 @@ FuncGraphPtr TransformGraphCondBranchNodes( ...@@ -119,7 +119,7 @@ FuncGraphPtr TransformGraphCondBranchNodes(
std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node; std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node;
// record the node input to be replaced // record the node input to be replaced
NodeInputReplMap repl_node_inputs; NodeInputReplMap repl_node_inputs;
const AnfNodeSet &nodes = manager->nodes()[graph]; const AnfNodeSet &nodes = graph->nodes();
for (auto &node : nodes) { for (auto &node : nodes) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {
...@@ -436,7 +436,7 @@ FuncGraphPtr TransformGraphDependNode( ...@@ -436,7 +436,7 @@ FuncGraphPtr TransformGraphDependNode(
ResetSharedOp(); ResetSharedOp();
std::shared_ptr<std::unordered_map<AnfNodePtr, AnfNodePtr>> repl_node = std::shared_ptr<std::unordered_map<AnfNodePtr, AnfNodePtr>> repl_node =
std::make_shared<std::unordered_map<AnfNodePtr, AnfNodePtr>>(); // record the node to be replaced std::make_shared<std::unordered_map<AnfNodePtr, AnfNodePtr>>(); // record the node to be replaced
const AnfNodeSet &nodes = manager->nodes()[graph]; const AnfNodeSet &nodes = graph->nodes();
for (auto &node : nodes) { for (auto &node : nodes) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {
......
...@@ -391,7 +391,7 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) { ...@@ -391,7 +391,7 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) {
FuncGraphPtr func_graph = res->func_graph(); FuncGraphPtr func_graph = res->func_graph();
auto manager = res->manager(); auto manager = res->manager();
// Remove duplicated value nodes, due to replace operation, can't use reference. // Remove duplicated value nodes, due to replace operation, can't use reference.
auto value_nodes = manager->valuenodes()[func_graph]; auto value_nodes = func_graph->value_nodes();
HashCache hash_cache; HashCache hash_cache;
HashValue hashes; HashValue hashes;
for (const auto &value_pair : value_nodes) { for (const auto &value_pair : value_nodes) {
......
...@@ -488,12 +488,12 @@ void CompileGraph::AddExternal(const LinConvertResult &result) { ...@@ -488,12 +488,12 @@ void CompileGraph::AddExternal(const LinConvertResult &result) {
void TraverseGraphMap( void TraverseGraphMap(
const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr,
const FuncGraphToAnfNodeCounterMap<AnfNodePtr> &cts, const FuncGraphSet &fgs,
const std::function<std::shared_ptr<FuncGraph>(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) { const std::function<std::shared_ptr<FuncGraph>(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) {
MS_EXCEPTION_IF_NULL(manager_ptr); MS_EXCEPTION_IF_NULL(manager_ptr);
MS_EXCEPTION_IF_NULL(tr); MS_EXCEPTION_IF_NULL(tr);
for (const auto &ct_graphs : cts) { for (const auto &fg : fgs) {
for (const auto &ct_any : ct_graphs.second) { for (const auto &ct_any : fg->value_nodes()) {
AnfNodePtr const_primitive_node = ct_any.first; AnfNodePtr const_primitive_node = ct_any.first;
if (const_primitive_node != nullptr && IsValueNode<Primitive>(const_primitive_node)) { if (const_primitive_node != nullptr && IsValueNode<Primitive>(const_primitive_node)) {
auto users = manager_ptr->node_users()[const_primitive_node]; auto users = manager_ptr->node_users()[const_primitive_node];
...@@ -553,8 +553,8 @@ FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph) { ...@@ -553,8 +553,8 @@ FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph) {
}; };
FuncGraphTransaction tr = manager_ptr->Transact(); FuncGraphTransaction tr = manager_ptr->Transact();
auto &cts = manager_ptr->valuenodes(); auto &fgs = manager_ptr->func_graphs();
TraverseGraphMap(manager_ptr, &tr, cts, get_prim_graph); TraverseGraphMap(manager_ptr, &tr, fgs, get_prim_graph);
return graph; return graph;
} }
......
...@@ -132,18 +132,6 @@ class NestingSpecs { ...@@ -132,18 +132,6 @@ class NestingSpecs {
CheckAnfNodeCounter(counter_p); CheckAnfNodeCounter(counter_p);
return; return;
} }
auto counter_pair = dynamic_pointer_cast<CounterAnfNodeCollector<CNodeIndexPairPtr>>(results);
if (counter_pair != nullptr) {
CheckCNodeIndexPairCounter(counter_pair);
return;
}
auto nodes = dynamic_pointer_cast<NodesCollector>(results);
if (nodes != nullptr) {
CheckNodes(nodes);
return;
}
} }
private: private:
...@@ -205,33 +193,7 @@ class NestingSpecs { ...@@ -205,33 +193,7 @@ class NestingSpecs {
ASSERT_EQ(clean_results, expected_); ASSERT_EQ(clean_results, expected_);
} }
void CheckNodes(std::shared_ptr<NodesCollector> results) {
std::map<std::string, std::set<std::string>> clean_results;
for (auto& iter : results->nodes_analysis()) {
auto key = iter.first;
auto value = iter.second;
if (key == nullptr) {
continue;
}
std::string k = Name(key);
std::set<std::string> v;
for (auto& node : value) {
if (!node->isa<CNode>() && !Name(node).empty()) {
v.insert(Name(node));
}
}
if (!v.empty()) {
clean_results[k] = v;
}
}
ASSERT_EQ(clean_results, expected_);
}
// Add CheckNesting function // Add CheckNesting function
void CheckAnfNodeCounter(std::shared_ptr<CounterAnfNodeCollector<AnfNodePtr>> results) { void CheckAnfNodeCounter(std::shared_ptr<CounterAnfNodeCollector<AnfNodePtr>> results) {
std::map<std::string, std::set<std::string>> clean_results; std::map<std::string, std::set<std::string>> clean_results;
for (auto& iter : results->count_nodes_map()) { for (auto& iter : results->count_nodes_map()) {
...@@ -258,32 +220,6 @@ class NestingSpecs { ...@@ -258,32 +220,6 @@ class NestingSpecs {
ASSERT_EQ(clean_results, expected_); ASSERT_EQ(clean_results, expected_);
} }
void CheckCNodeIndexPairCounter(std::shared_ptr<CounterAnfNodeCollector<CNodeIndexPairPtr>> results) {
std::map<std::string, std::set<std::string>> clean_results;
for (auto& iter : results->count_nodes_map()) {
auto key = iter.first;
auto value = iter.second;
if (key == nullptr) {
continue;
}
std::string k = Name(key);
std::set<std::string> v;
for (auto& node : value) {
auto fg = node.first->first;
if (!Name(fg).empty()) {
v.insert(Name(fg));
}
}
if (!v.empty()) {
clean_results[k] = v;
}
}
ASSERT_EQ(clean_results, expected_);
}
void CheckGraphCounter(std::shared_ptr<CounterFuncGraphCollector> results) { void CheckGraphCounter(std::shared_ptr<CounterFuncGraphCollector> results) {
std::map<std::string, std::set<std::string>> clean_results; std::map<std::string, std::set<std::string>> clean_results;
for (auto& iter : results->count_func_graphs_map()) { for (auto& iter : results->count_func_graphs_map()) {
...@@ -471,17 +407,10 @@ std::vector<FuncGraphPtr> MakeNestedGraph2() { ...@@ -471,17 +407,10 @@ std::vector<FuncGraphPtr> MakeNestedGraph2() {
} }
// Add TestManager::CheckManager function to checkout the result // Add TestManager::CheckManager function to checkout the result
void TestManager::CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng) { void TestManager::CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng) {
auto size = mng->func_graphs().size(); auto size = mng->func_graphs().size();
ASSERT_EQ(size + 1, mng->nodes().size());
ASSERT_EQ(size, mng->free_variables_total().size()); ASSERT_EQ(size, mng->free_variables_total().size());
ASSERT_EQ(size, mng->valuenodes().size());
ASSERT_EQ(size, mng->free_variables_direct().size());
ASSERT_EQ(size, mng->func_graph_cnodes_index().size());
ASSERT_EQ(size, mng->func_graph_parents_direct().size());
ASSERT_EQ(size, mng->func_graphs_used().size());
} }
TEST_F(TestManager, test_scalar_add_manual) { TEST_F(TestManager, test_scalar_add_manual) {
...@@ -525,31 +454,26 @@ TEST_F(TestManager, test_nested_manual) { ...@@ -525,31 +454,26 @@ TEST_F(TestManager, test_nested_manual) {
ASSERT_EQ(1, mng->roots().size()); ASSERT_EQ(1, mng->roots().size());
CheckAnalysisSize(mng); CheckAnalysisSize(mng);
auto nodes = mng->nodes(); ASSERT_EQ(2, f->nodes().size());
ASSERT_EQ(3, nodes[nullptr].size()); ASSERT_EQ(1, g->nodes().size());
ASSERT_EQ(2, nodes[f].size());
ASSERT_EQ(1, nodes[g].size());
auto users = mng->node_users(); auto users = mng->node_users();
for (auto& iter : users) { for (auto& iter : users) {
ASSERT_EQ(1, iter.second.size()); ASSERT_EQ(1, iter.second.size());
} }
auto graphs_used = mng->func_graphs_used(); ASSERT_EQ(1, f->func_graphs_used().size());
ASSERT_EQ(1, graphs_used[f].size()); ASSERT_EQ(0, g->func_graphs_used().size());
ASSERT_EQ(0, graphs_used[g].size());
auto fv_direct = mng->free_variables_direct(); ASSERT_EQ(0, f->free_variables().size());
ASSERT_EQ(0, fv_direct[f].size()); ASSERT_EQ(1, g->free_variables().size());
ASSERT_EQ(1, fv_direct[g].size());
auto fv_total = mng->free_variables_total(); auto fv_total = mng->free_variables_total();
ASSERT_EQ(0, fv_total[f].size()); ASSERT_EQ(0, fv_total[f].size());
ASSERT_EQ(1, fv_total[g].size()); ASSERT_EQ(1, fv_total[g].size());
auto cnodes = mng->func_graph_cnodes_index(); ASSERT_EQ(0, f->func_graph_cnodes_index().size());
ASSERT_EQ(0, cnodes[f].size()); ASSERT_EQ(1, g->func_graph_cnodes_index().size());
ASSERT_EQ(1, cnodes[g].size());
} }
TEST_F(TestManager, test_deep_nested2_manual) { TEST_F(TestManager, test_deep_nested2_manual) {
...@@ -567,7 +491,7 @@ TEST_F(TestManager, test_deep_nested2_manual) { ...@@ -567,7 +491,7 @@ TEST_F(TestManager, test_deep_nested2_manual) {
ASSERT_EQ(3, mng->func_graphs().size()); ASSERT_EQ(3, mng->func_graphs().size());
ASSERT_EQ(1, mng->roots().size()); ASSERT_EQ(1, mng->roots().size());
ASSERT_EQ(4, mng->nodes().size()); ASSERT_EQ(4, gfn->nodes().size());
ASSERT_EQ(20, mng->all_nodes().size()); ASSERT_EQ(20, mng->all_nodes().size());
ASSERT_EQ(25, mng->node_users().size()); ASSERT_EQ(25, mng->node_users().size());
CheckAnalysisSize(mng); CheckAnalysisSize(mng);
...@@ -631,7 +555,6 @@ TEST_F(TestManager, test_deep_nested_manual) { ...@@ -631,7 +555,6 @@ TEST_F(TestManager, test_deep_nested_manual) {
ASSERT_EQ(3, mng->func_graphs().size()); ASSERT_EQ(3, mng->func_graphs().size());
ASSERT_EQ(1, mng->roots().size()); ASSERT_EQ(1, mng->roots().size());
ASSERT_EQ(4, mng->nodes().size());
ASSERT_EQ(20, mng->all_nodes().size()); ASSERT_EQ(20, mng->all_nodes().size());
CheckAnalysisSize(mng); CheckAnalysisSize(mng);
} }
...@@ -716,12 +639,12 @@ TEST_F(TestManager, test_drop_root) { ...@@ -716,12 +639,12 @@ TEST_F(TestManager, test_drop_root) {
FuncGraphPtr fg = getPyFun("ir_get_fn"); FuncGraphPtr fg = getPyFun("ir_get_fn");
auto mng = Manage(fg); auto mng = Manage(fg);
const FuncGraphToAnfNodeMap& nodes = mng->nodes(); const auto &fgs = mng->func_graphs();
ASSERT_TRUE(nodes.find(fg) != nodes.end()); ASSERT_TRUE(fgs.contains(fg));
FuncGraphSet s; FuncGraphSet s;
s.add(fg); s.add(fg);
mng->MaybeDropFuncGraphs(s); mng->MaybeDropFuncGraphs(s);
ASSERT_TRUE(nodes.find(fg) != nodes.end()); ASSERT_TRUE(fgs.contains(fg));
} }
TEST_F(TestManager, test_keep_roots) { TEST_F(TestManager, test_keep_roots) {
......
...@@ -26,15 +26,14 @@ ...@@ -26,15 +26,14 @@
namespace mindspore { namespace mindspore {
void CheckNoFreeVariables(FuncGraphPtr root) { void CheckNoFreeVariables(FuncGraphPtr root) {
auto mng = Manage(root); auto mng = Manage(root);
for (auto &iter : mng->nodes()) { for (auto &iter : mng->func_graphs()) {
auto g = iter.first; auto g = iter;
auto nodes = iter.second;
if (g == nullptr) { if (g == nullptr) {
continue; continue;
} }
ASSERT_TRUE(g->parent() == nullptr); ASSERT_TRUE(g->parent() == nullptr);
auto nodes = g->nodes();
for (auto &node : nodes) { for (auto &node : nodes) {
ASSERT_EQ(node->func_graph(), g); ASSERT_EQ(node->func_graph(), g);
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册