Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
848d1920
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
848d1920
编写于
5月 25, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
5月 25, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1359 Optimize the IR modules.
Merge pull request !1359 from ZhangQinghua/master
上级
a3d9c9a8
dbb86cb1
变更
12
展开全部
隐藏空白更改
内联
并排
Showing
12 changed file
with
438 addition
and
587 deletion
+438
-587
mindspore/ccsrc/ir/base.h
mindspore/ccsrc/ir/base.h
+1
-0
mindspore/ccsrc/ir/func_graph.cc
mindspore/ccsrc/ir/func_graph.cc
+198
-33
mindspore/ccsrc/ir/func_graph.h
mindspore/ccsrc/ir/func_graph.h
+61
-7
mindspore/ccsrc/ir/func_graph_cloner.cc
mindspore/ccsrc/ir/func_graph_cloner.cc
+6
-6
mindspore/ccsrc/ir/manager.cc
mindspore/ccsrc/ir/manager.cc
+115
-275
mindspore/ccsrc/ir/manager.h
mindspore/ccsrc/ir/manager.h
+33
-164
mindspore/ccsrc/optimizer/ad/dfunctor.cc
mindspore/ccsrc/optimizer/ad/dfunctor.cc
+1
-1
mindspore/ccsrc/optimizer/irpass/branch_culling.cc
mindspore/ccsrc/optimizer/irpass/branch_culling.cc
+2
-2
mindspore/ccsrc/pipeline/action.cc
mindspore/ccsrc/pipeline/action.cc
+1
-1
mindspore/ccsrc/vm/transform.cc
mindspore/ccsrc/vm/transform.cc
+5
-5
tests/ut/cpp/ir/manager_test.cc
tests/ut/cpp/ir/manager_test.cc
+12
-89
tests/ut/cpp/optimizer/cconv_test.cc
tests/ut/cpp/optimizer/cconv_test.cc
+3
-4
未找到文件。
mindspore/ccsrc/ir/base.h
浏览文件 @
848d1920
...
@@ -29,6 +29,7 @@
...
@@ -29,6 +29,7 @@
#include "utils/visible.h"
#include "utils/visible.h"
#include "utils/log_adapter.h"
#include "utils/log_adapter.h"
#include "utils/ordered_set.h"
#include "utils/ordered_set.h"
#include "utils/ordered_map.h"
namespace
mindspore
{
namespace
mindspore
{
template
<
typename
T
>
template
<
typename
T
>
...
...
mindspore/ccsrc/ir/func_graph.cc
浏览文件 @
848d1920
...
@@ -47,6 +47,7 @@ FuncGraph::FuncGraph()
...
@@ -47,6 +47,7 @@ FuncGraph::FuncGraph()
:
flags_
(),
:
flags_
(),
transforms_
(),
transforms_
(),
parameter_default_value_
(),
parameter_default_value_
(),
seen_
(
0
),
parameters_
(),
parameters_
(),
has_vararg_
(
false
),
has_vararg_
(
false
),
has_kwarg_
(
false
),
has_kwarg_
(
false
),
...
@@ -195,25 +196,93 @@ GraphDebugInfoPtr FuncGraph::debug_info() {
...
@@ -195,25 +196,93 @@ GraphDebugInfoPtr FuncGraph::debug_info() {
return
this
->
debug_info_
;
return
this
->
debug_info_
;
}
}
const
AnfNodeSet
&
FuncGraph
::
nodes
()
{
const
AnfNodeSet
&
FuncGraph
::
nodes
()
{
return
nodes_
;
}
auto
mng
=
manager_
.
lock
();
MS_EXCEPTION_IF_NULL
(
mng
);
void
FuncGraph
::
CopyNodes
(
const
FuncGraphPtr
&
source
)
{
nodes_
=
source
->
nodes
();
}
auto
&
nodes
=
mng
->
nodes
();
return
nodes
[
shared_from_base
<
FuncGraph
>
()];
void
FuncGraph
::
ClearNodes
()
{
nodes_
.
clear
();
}
void
FuncGraph
::
AddNode
(
AnfNodePtr
node
)
{
nodes_
.
add
(
node
);
}
void
FuncGraph
::
DropNode
(
AnfNodePtr
node
)
{
nodes_
.
erase
(
node
);
auto
graph
=
node
->
func_graph
();
// Remove the node from order list.
if
(
graph
)
{
graph
->
EraseUnusedNodeInOrder
(
node
);
}
}
}
const
AnfNodeCounterMap
&
FuncGraph
::
value_nodes
()
{
const
AnfNodeCounterMap
&
FuncGraph
::
value_nodes
()
{
return
value_nodes_
;
}
auto
mng
=
manager_
.
lock
();
MS_EXCEPTION_IF_NULL
(
mng
);
void
FuncGraph
::
CopyValueNodes
(
const
FuncGraphPtr
&
source
)
{
auto
&
cts
=
mng
->
valuenodes
();
auto
&
others
=
source
->
value_nodes
();
return
cts
[
shared_from_base
<
FuncGraph
>
()];
for
(
auto
it
=
others
.
begin
();
it
!=
others
.
end
();
it
++
)
{
AddValueNode
(
it
->
first
,
it
->
second
);
}
}
}
const
AnfNodeCounterMap
&
FuncGraph
::
free_variables_direct
()
{
void
FuncGraph
::
ClearValueNodes
()
{
value_nodes_
.
clear
();
}
auto
mng
=
manager_
.
lock
();
MS_EXCEPTION_IF_NULL
(
mng
);
void
FuncGraph
::
AddValueNode
(
AnfNodePtr
node
,
int
count
)
{
auto
&
fv_direct
=
mng
->
free_variables_direct
();
if
(
value_nodes_
.
count
(
node
)
==
0
)
{
return
fv_direct
[
shared_from_base
<
FuncGraph
>
()];
value_nodes_
[
node
]
=
count
;
}
else
{
value_nodes_
[
node
]
+=
count
;
}
}
void
FuncGraph
::
DropValueNode
(
AnfNodePtr
node
)
{
if
(
value_nodes_
.
count
(
node
)
!=
0
)
{
if
(
value_nodes_
[
node
]
==
1
)
{
(
void
)
value_nodes_
.
erase
(
node
);
}
else
{
value_nodes_
[
node
]
--
;
if
(
value_nodes_
[
node
]
<
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"Count of ValueNode '"
<<
node
<<
"' dec from 0. NodeInfo: "
<<
trace
::
GetDebugInfo
(
debug_info
());
}
}
}
}
const
AnfNodeCounterMap
&
FuncGraph
::
free_variables
()
{
return
free_variables_
;
}
void
FuncGraph
::
CopyFreeVariables
(
const
FuncGraphPtr
&
source
)
{
auto
&
others
=
source
->
free_variables
();
for
(
auto
it
=
others
.
begin
();
it
!=
others
.
end
();
it
++
)
{
if
(
it
->
first
->
func_graph
().
get
()
!=
this
)
{
(
void
)
AddFreeVariable
(
it
->
first
,
it
->
second
);
}
}
}
void
FuncGraph
::
ClearFreeVariables
()
{
free_variables_
.
clear
();
}
bool
FuncGraph
::
AddFreeVariable
(
AnfNodePtr
node
,
int
count
)
{
if
(
free_variables_
.
count
(
node
)
==
0
)
{
free_variables_
[
node
]
=
count
;
return
true
;
}
else
{
free_variables_
[
node
]
+=
count
;
return
false
;
}
}
bool
FuncGraph
::
DropFreeVariable
(
AnfNodePtr
node
)
{
if
(
free_variables_
.
count
(
node
)
!=
0
)
{
if
(
free_variables_
[
node
]
==
1
)
{
(
void
)
free_variables_
.
erase
(
node
);
return
true
;
}
else
{
free_variables_
[
node
]
--
;
if
(
free_variables_
[
node
]
<
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"Count of free variable '"
<<
node
<<
"' dec from 0. NodeInfo: "
<<
trace
::
GetDebugInfo
(
debug_info
());
}
}
}
return
false
;
}
}
const
BaseRefCounterMap
&
FuncGraph
::
free_variables_total
()
{
const
BaseRefCounterMap
&
FuncGraph
::
free_variables_total
()
{
...
@@ -249,11 +318,42 @@ std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() {
...
@@ -249,11 +318,42 @@ std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() {
return
func_graphs
;
return
func_graphs
;
}
}
const
FuncGraphCounterMap
&
FuncGraph
::
func_graphs_used
()
{
const
FuncGraphCounterMap
&
FuncGraph
::
func_graphs_used
()
{
return
func_graphs_used_
;
}
auto
mng
=
manager_
.
lock
();
MS_EXCEPTION_IF_NULL
(
mng
);
void
FuncGraph
::
CopyFuncGraphsUsed
(
const
FuncGraphPtr
&
source
)
{
auto
&
used
=
mng
->
func_graphs_used
();
auto
&
others
=
source
->
func_graphs_used
();
return
used
[
shared_from_base
<
FuncGraph
>
()];
for
(
auto
it
=
others
.
begin
();
it
!=
others
.
end
();
it
++
)
{
(
void
)
AddFuncGraphUsed
(
it
->
first
,
it
->
second
);
}
func_graphs_used_
.
erase
(
source
);
}
void
FuncGraph
::
ClearFuncGraphsUsed
()
{
func_graphs_used_
.
clear
();
}
bool
FuncGraph
::
AddFuncGraphUsed
(
FuncGraphPtr
fg
,
int
count
)
{
if
(
func_graphs_used_
.
count
(
fg
)
==
0
)
{
func_graphs_used_
[
fg
]
=
count
;
return
true
;
}
else
{
func_graphs_used_
[
fg
]
+=
count
;
return
false
;
}
}
bool
FuncGraph
::
DropFuncGraphUsed
(
FuncGraphPtr
fg
)
{
if
(
func_graphs_used_
.
count
(
fg
)
!=
0
)
{
if
(
func_graphs_used_
[
fg
]
==
1
)
{
(
void
)
func_graphs_used_
.
erase
(
fg
);
return
true
;
}
else
{
func_graphs_used_
[
fg
]
--
;
if
(
func_graphs_used_
[
fg
]
<
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"Count of FuncGraph '"
<<
fg
<<
"' dec from 0. NodeInfo: "
<<
trace
::
GetDebugInfo
(
debug_info
());
}
}
}
return
false
;
}
}
const
FuncGraphSet
&
FuncGraph
::
func_graphs_used_total
()
{
const
FuncGraphSet
&
FuncGraph
::
func_graphs_used_total
()
{
...
@@ -263,15 +363,75 @@ const FuncGraphSet &FuncGraph::func_graphs_used_total() {
...
@@ -263,15 +363,75 @@ const FuncGraphSet &FuncGraph::func_graphs_used_total() {
return
used
;
return
used
;
}
}
const
CNodeIndexCounterMap
&
FuncGraph
::
func_graph_cnodes_index
()
{
const
CNodeIndexCounterMap
&
FuncGraph
::
func_graph_cnodes_index
()
{
return
func_graph_cnodes_index_
;
}
auto
mng
=
manager_
.
lock
();
if
(
mng
==
nullptr
)
{
void
FuncGraph
::
CopyFuncGraphCNodesIndex
(
const
FuncGraphPtr
&
source
)
{
MS_LOG
(
EXCEPTION
)
<<
"BUG: no manager for this func graph: "
<<
ToString
()
auto
&
others
=
source
->
func_graph_cnodes_index
();
<<
" NodeInfo: "
<<
trace
::
GetDebugInfo
(
debug_info
());
for
(
auto
it
=
others
.
begin
();
it
!=
others
.
end
();
it
++
)
{
// Ignore the user graph who may own itself.
auto
fg
=
it
->
first
->
first
->
func_graph
();
MS_EXCEPTION_IF_NULL
(
fg
);
if
(
fg
.
get
()
!=
this
)
{
AddFuncGraphCNodeIndex
(
it
->
first
,
it
->
second
);
}
}
}
void
FuncGraph
::
ClearFuncGraphCNodesIndex
()
{
func_graph_cnodes_index_
.
clear
();
}
void
FuncGraph
::
AddFuncGraphCNodeIndex
(
CNodeIndexPairPtr
pair
,
int
count
)
{
if
(
func_graph_cnodes_index_
.
count
(
pair
)
==
0
)
{
func_graph_cnodes_index_
[
pair
]
=
count
;
}
else
{
func_graph_cnodes_index_
[
pair
]
+=
count
;
}
}
void
FuncGraph
::
DropFuncGraphCNodeIndex
(
CNodeIndexPairPtr
pair
)
{
if
(
func_graph_cnodes_index_
.
count
(
pair
)
!=
0
)
{
if
(
func_graph_cnodes_index_
[
pair
]
==
1
)
{
(
void
)
func_graph_cnodes_index_
.
erase
(
pair
);
}
else
{
func_graph_cnodes_index_
[
pair
]
--
;
if
(
func_graph_cnodes_index_
[
pair
]
<
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"Count of CNode/Index '"
<<
pair
->
first
<<
"/"
<<
pair
->
second
<<
"' dec from 0. NodeInfo: "
<<
trace
::
GetDebugInfo
(
debug_info
());
}
}
}
}
const
FuncGraphCounterMap
&
FuncGraph
::
j_func_graphs
()
{
return
j_func_graphs_
;
}
void
FuncGraph
::
CopyJFuncGraphs
(
const
FuncGraphPtr
&
source
)
{
auto
&
others
=
source
->
j_func_graphs
();
for
(
auto
it
=
others
.
begin
();
it
!=
others
.
end
();
it
++
)
{
AddJFuncGraph
(
it
->
first
,
it
->
second
);
}
}
void
FuncGraph
::
ClearJFuncGraphs
()
{
j_func_graphs_
.
clear
();
}
void
FuncGraph
::
AddJFuncGraph
(
FuncGraphPtr
fg
,
int
count
)
{
if
(
j_func_graphs_
.
count
(
fg
)
==
0
)
{
j_func_graphs_
[
fg
]
=
count
;
}
else
{
j_func_graphs_
[
fg
]
+=
count
;
}
}
void
FuncGraph
::
DropJFuncGraph
(
FuncGraphPtr
fg
)
{
if
(
j_func_graphs_
.
count
(
fg
)
!=
0
)
{
if
(
j_func_graphs_
[
fg
]
==
1
)
{
(
void
)
j_func_graphs_
.
erase
(
fg
);
}
else
{
j_func_graphs_
[
fg
]
--
;
if
(
j_func_graphs_
[
fg
]
<
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"Count of J FuncGraph '"
<<
fg
<<
"' dec from 0. NodeInfo: "
<<
trace
::
GetDebugInfo
(
debug_info
());
}
}
}
}
MS_EXCEPTION_IF_NULL
(
mng
);
auto
&
cnode
=
mng
->
func_graph_cnodes_index
();
return
cnode
[
shared_from_base
<
FuncGraph
>
()];
}
}
FuncGraphPtr
FuncGraph
::
parent
()
{
FuncGraphPtr
FuncGraph
::
parent
()
{
...
@@ -662,10 +822,10 @@ void FuncGraph::EraseUnusedNodeInOrder() {
...
@@ -662,10 +822,10 @@ void FuncGraph::EraseUnusedNodeInOrder() {
if
(
has_flag
(
GRAPH_FLAG_HAS_EFFECT
))
{
if
(
has_flag
(
GRAPH_FLAG_HAS_EFFECT
))
{
auto
mng
=
manager_
.
lock
();
auto
mng
=
manager_
.
lock
();
if
(
mng
)
{
if
(
mng
)
{
auto
nodes
=
mng
->
nodes
()[
shared_from_base
<
FuncGraph
>
()]
;
auto
&
all_nodes
=
nodes
()
;
// Erase unused cnode.
// Erase unused cnode.
for
(
auto
it
=
order_
.
begin
();
it
!=
order_
.
end
();)
{
for
(
auto
it
=
order_
.
begin
();
it
!=
order_
.
end
();)
{
if
(
nodes
.
count
(
*
it
))
{
if
(
all_
nodes
.
count
(
*
it
))
{
(
void
)
it
++
;
(
void
)
it
++
;
}
else
{
}
else
{
MS_LOG
(
DEBUG
)
<<
"Remove node "
<<
(
*
it
)
->
ToString
()
<<
" in graph "
<<
ToString
()
<<
" order."
;
MS_LOG
(
DEBUG
)
<<
"Remove node "
<<
(
*
it
)
->
ToString
()
<<
" in graph "
<<
ToString
()
<<
" order."
;
...
@@ -702,11 +862,11 @@ void FuncGraph::CheckOrder() {
...
@@ -702,11 +862,11 @@ void FuncGraph::CheckOrder() {
}
}
auto
mng
=
manager_
.
lock
();
auto
mng
=
manager_
.
lock
();
if
(
mng
!=
nullptr
)
{
if
(
mng
!=
nullptr
)
{
const
auto
&
nodes
=
mng
->
nodes
()[
shared_from_base
<
FuncGraph
>
()]
;
const
auto
&
all_nodes
=
nodes
()
;
if
(
nodes
.
size
()
!=
(
order_
.
size
()
+
parameters_
.
size
()))
{
if
(
all_
nodes
.
size
()
!=
(
order_
.
size
()
+
parameters_
.
size
()))
{
DumpCNodeList
();
DumpCNodeList
();
MS_LOG
(
EXCEPTION
)
<<
"CNode order size "
<<
order_
.
size
()
<<
" is not equal to managed node size "
MS_LOG
(
EXCEPTION
)
<<
"CNode order size "
<<
order_
.
size
()
<<
" is not equal to managed node size "
<<
nodes
.
size
()
-
parameters_
.
size
()
<<
"."
;
<<
all_
nodes
.
size
()
-
parameters_
.
size
()
<<
"."
;
}
}
}
}
MS_LOG
(
DEBUG
)
<<
"Check order okay."
;
MS_LOG
(
DEBUG
)
<<
"Check order okay."
;
...
@@ -840,6 +1000,11 @@ void FuncGraph::SetEffectDepends(const std::vector<AnfNodePtr> &depend_inputs) {
...
@@ -840,6 +1000,11 @@ void FuncGraph::SetEffectDepends(const std::vector<AnfNodePtr> &depend_inputs) {
}
}
}
}
size_t
NewFgSeenGeneration
()
{
static
size_t
fg_seen_generation
=
0
;
return
++
fg_seen_generation
;
}
const
PrimitivePtr
FuncGraphTransform
::
func_graph_prim_
=
std
::
make_shared
<
Primitive
>
(
"FuncGraph"
);
const
PrimitivePtr
FuncGraphTransform
::
func_graph_prim_
=
std
::
make_shared
<
Primitive
>
(
"FuncGraph"
);
const
char
kFuncGraphFlagUndetermined
[]
=
"Undeterminate"
;
const
char
kFuncGraphFlagUndetermined
[]
=
"Undeterminate"
;
}
// namespace mindspore
}
// namespace mindspore
mindspore/ccsrc/ir/func_graph.h
浏览文件 @
848d1920
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
#include <memory>
#include <memory>
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <unordered_set>
#include <functional>
#include "ir/anf.h"
#include "ir/anf.h"
#include "ir/manager.h"
#include "ir/manager.h"
...
@@ -36,8 +37,13 @@
...
@@ -36,8 +37,13 @@
namespace
mindspore
{
namespace
mindspore
{
using
BaseRefCounterMap
=
OrderedMap
<
BaseRef
,
int
,
BaseRefHash
>
;
using
BaseRefCounterMap
=
OrderedMap
<
BaseRef
,
int
,
BaseRefHash
>
;
using
FuncGraphCounterMap
=
OrderedMap
<
FuncGraphPtr
,
int
>
;
using
FuncGraphCounterMap
=
OrderedMap
<
FuncGraphPtr
,
int
>
;
using
AnfNodeCounterMap
=
OrderedMap
<
AnfNodePtr
,
int
>
;
using
CNodeIndexCounterMap
=
OrderedMap
<
CNodeIndexPairPtr
,
int
,
CNodeIndexHasher
,
CNodeIndexEqual
>
;
template
<
typename
ValueT
,
class
CounterHash
=
std
::
hash
<
ValueT
>,
class
CounterEqual
=
std
::
equal_to
<
ValueT
>>
using
CounterOrderedMap
=
OrderedMap
<
ValueT
,
int
,
CounterHash
,
CounterEqual
>
;
using
AnfNodeCounterMap
=
CounterOrderedMap
<
AnfNodePtr
>
;
using
CNodeIndexCounterMap
=
CounterOrderedMap
<
CNodeIndexPairPtr
,
CNodeIndexHasher
,
CNodeIndexEqual
>
;
using
FuncGraphMap
=
OrderedMap
<
FuncGraphPtr
,
int
>
;
const
char
FUNC_GRAPH_FLAG_IGNORE_VALUES
[]
=
"ignore_values"
;
const
char
FUNC_GRAPH_FLAG_IGNORE_VALUES
[]
=
"ignore_values"
;
const
char
FUNC_GRAPH_FLAG_DEFER_INLINE
[]
=
"defer_inline"
;
const
char
FUNC_GRAPH_FLAG_DEFER_INLINE
[]
=
"defer_inline"
;
...
@@ -183,12 +189,24 @@ class FuncGraph : public FuncGraphBase {
...
@@ -183,12 +189,24 @@ class FuncGraph : public FuncGraphBase {
// get all nodes belonging to this func graph
// get all nodes belonging to this func graph
const
AnfNodeSet
&
nodes
();
const
AnfNodeSet
&
nodes
();
void
CopyNodes
(
const
FuncGraphPtr
&
source
);
void
ClearNodes
();
void
AddNode
(
AnfNodePtr
node
);
void
DropNode
(
AnfNodePtr
node
);
// get all value_nodes belonging to this func graph
// get all value_nodes belonging to this func graph
const
AnfNodeCounterMap
&
value_nodes
();
const
AnfNodeCounterMap
&
value_nodes
();
void
CopyValueNodes
(
const
FuncGraphPtr
&
source
);
// get all vars directly pointed to in this func graph
void
ClearValueNodes
();
const
AnfNodeCounterMap
&
free_variables_direct
();
void
AddValueNode
(
AnfNodePtr
node
,
int
count
=
1
);
void
DropValueNode
(
AnfNodePtr
node
);
// get all free vars directly used in this func graph
const
AnfNodeCounterMap
&
free_variables
();
void
CopyFreeVariables
(
const
FuncGraphPtr
&
source
);
void
ClearFreeVariables
();
bool
AddFreeVariable
(
AnfNodePtr
node
,
int
count
=
1
);
bool
DropFreeVariable
(
AnfNodePtr
node
);
// get all vars required by this func graph
// get all vars required by this func graph
const
BaseRefCounterMap
&
free_variables_total
();
const
BaseRefCounterMap
&
free_variables_total
();
...
@@ -199,14 +217,29 @@ class FuncGraph : public FuncGraphBase {
...
@@ -199,14 +217,29 @@ class FuncGraph : public FuncGraphBase {
// get all vars that are func graphs
// get all vars that are func graphs
std
::
vector
<
FuncGraphPtr
>
free_variables_func_graphs
();
std
::
vector
<
FuncGraphPtr
>
free_variables_func_graphs
();
// get all
func graphs
directly used by this func graph
// get all
value nodes of func graph
directly used by this func graph
const
FuncGraphCounterMap
&
func_graphs_used
();
const
FuncGraphCounterMap
&
func_graphs_used
();
void
CopyFuncGraphsUsed
(
const
FuncGraphPtr
&
source
);
void
ClearFuncGraphsUsed
();
bool
AddFuncGraphUsed
(
FuncGraphPtr
fg
,
int
count
=
1
);
bool
DropFuncGraphUsed
(
FuncGraphPtr
fg
);
// get all value nodes of J func graph directly used by this func graph
const
FuncGraphCounterMap
&
j_func_graphs
();
void
CopyJFuncGraphs
(
const
FuncGraphPtr
&
source
);
void
ClearJFuncGraphs
();
void
AddJFuncGraph
(
FuncGraphPtr
fg
,
int
count
=
1
);
void
DropJFuncGraph
(
FuncGraphPtr
fg
);
// get all func graphs nested used by this func graph
// get all func graphs nested used by this func graph
const
FuncGraphSet
&
func_graphs_used_total
();
const
FuncGraphSet
&
func_graphs_used_total
();
// get all user value nodes of this func graph
// get all user value nodes of this func graph
, by CNode and its input's index
const
CNodeIndexCounterMap
&
func_graph_cnodes_index
();
const
CNodeIndexCounterMap
&
func_graph_cnodes_index
();
void
CopyFuncGraphCNodesIndex
(
const
FuncGraphPtr
&
source
);
void
ClearFuncGraphCNodesIndex
();
void
AddFuncGraphCNodeIndex
(
CNodeIndexPairPtr
node
,
int
count
=
1
);
void
DropFuncGraphCNodeIndex
(
CNodeIndexPairPtr
node
);
// Return the parent of this graph.
// Return the parent of this graph.
FuncGraphPtr
parent
();
FuncGraphPtr
parent
();
...
@@ -256,6 +289,7 @@ class FuncGraph : public FuncGraphBase {
...
@@ -256,6 +289,7 @@ class FuncGraph : public FuncGraphBase {
// parameter default value
// parameter default value
std
::
map
<
std
::
string
,
AnfNodePtr
>
parameter_default_value_
;
std
::
map
<
std
::
string
,
AnfNodePtr
>
parameter_default_value_
;
std
::
unordered_map
<
AnfNodePtr
,
AnfNodePtr
>
make_ref_params_
;
std
::
unordered_map
<
AnfNodePtr
,
AnfNodePtr
>
make_ref_params_
;
size_t
seen_
;
std
::
list
<
CNodePtr
>
GetOrderedCnodes
();
std
::
list
<
CNodePtr
>
GetOrderedCnodes
();
void
EraseUnusedNodeInOrder
(
const
AnfNodePtr
&
n
);
void
EraseUnusedNodeInOrder
(
const
AnfNodePtr
&
n
);
...
@@ -270,6 +304,24 @@ class FuncGraph : public FuncGraphBase {
...
@@ -270,6 +304,24 @@ class FuncGraph : public FuncGraphBase {
// graph is manipulated by manager and others
// graph is manipulated by manager and others
friend
FuncGraphManager
;
friend
FuncGraphManager
;
// all nodes of the function
AnfNodeSet
nodes_
;
// all value nodes of the function
AnfNodeCounterMap
value_nodes_
;
// all func graph value nodes of the function
FuncGraphCounterMap
func_graphs_used_
;
// all free variables of the function
AnfNodeCounterMap
free_variables_
;
// all value nodes calling J in the function
FuncGraphCounterMap
j_func_graphs_
;
// all user value nodes of this func graph, recording by CNode and its input's index
CNodeIndexCounterMap
func_graph_cnodes_index_
;
// parameters of this function
// parameters of this function
std
::
vector
<
AnfNodePtr
>
parameters_
;
std
::
vector
<
AnfNodePtr
>
parameters_
;
std
::
vector
<
AnfNodePtr
>
paramter_obj_nodes_
;
std
::
vector
<
AnfNodePtr
>
paramter_obj_nodes_
;
...
@@ -313,6 +365,8 @@ inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphP
...
@@ -313,6 +365,8 @@ inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphP
return
fg
->
NewCNode
(
inputs
);
return
fg
->
NewCNode
(
inputs
);
}
}
size_t
NewFgSeenGeneration
();
// Find the root cnodes of a segment of cnodes.
// Find the root cnodes of a segment of cnodes.
std
::
shared_ptr
<
OrderedSet
<
CNodePtr
>>
FindRoots
(
const
std
::
vector
<
CNodePtr
>
&
segment
);
std
::
shared_ptr
<
OrderedSet
<
CNodePtr
>>
FindRoots
(
const
std
::
vector
<
CNodePtr
>
&
segment
);
// Find the leaf cnodes of a segment of cnodes.
// Find the leaf cnodes of a segment of cnodes.
...
...
mindspore/ccsrc/ir/func_graph_cloner.cc
浏览文件 @
848d1920
...
@@ -123,7 +123,7 @@ void Cloner::CloneValueNodes(const FuncGraphPtr &func_graph) {
...
@@ -123,7 +123,7 @@ void Cloner::CloneValueNodes(const FuncGraphPtr &func_graph) {
if
(
!
clone_all_valuenodes_
)
{
if
(
!
clone_all_valuenodes_
)
{
return
;
return
;
}
}
auto
&
value_nodes
=
manager_
->
valuenodes
()[
func_graph
]
;
auto
&
value_nodes
=
func_graph
->
value_nodes
()
;
for
(
auto
&
value_node
:
value_nodes
)
{
for
(
auto
&
value_node
:
value_nodes
)
{
auto
old_node
=
value_node
.
first
;
auto
old_node
=
value_node
.
first
;
MS_EXCEPTION_IF_NULL
(
old_node
);
MS_EXCEPTION_IF_NULL
(
old_node
);
...
@@ -153,9 +153,9 @@ void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) {
...
@@ -153,9 +153,9 @@ void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) {
if
(
!
clone_all_used_graphs_
)
{
if
(
!
clone_all_used_graphs_
)
{
return
;
return
;
}
}
auto
&
used
_graphs
=
manager_
->
func_graphs_used
()[
func_graph
]
;
auto
&
used
=
func_graph
->
func_graphs_used
()
;
for
(
auto
&
used_graph
:
used_graphs
)
{
for
(
auto
&
fg
:
used
)
{
todo_
.
push_back
({
used_graph
.
first
,
nullptr
,
{}});
todo_
.
push_back
({
fg
.
first
,
nullptr
,
{}});
}
}
}
}
...
@@ -185,7 +185,7 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const Func
...
@@ -185,7 +185,7 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const Func
}
}
target_func_graph
->
set_return
(
return_node
);
target_func_graph
->
set_return
(
return_node
);
auto
&
cnodes
=
manager_
->
func_graph_cnodes_index
()[
func_graph
]
;
auto
&
cnodes
=
func_graph
->
func_graph_cnodes_index
()
;
for
(
auto
&
cnode
:
cnodes
)
{
for
(
auto
&
cnode
:
cnodes
)
{
auto
parent
=
cnode
.
first
->
first
->
cast
<
CNodePtr
>
();
auto
parent
=
cnode
.
first
->
first
->
cast
<
CNodePtr
>
();
auto
valuenode
=
parent
->
input
(
cnode
.
first
->
second
);
auto
valuenode
=
parent
->
input
(
cnode
.
first
->
second
);
...
@@ -441,7 +441,7 @@ void Cloner::CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &t
...
@@ -441,7 +441,7 @@ void Cloner::CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &t
MS_EXCEPTION_IF_NULL
(
func_graph
);
MS_EXCEPTION_IF_NULL
(
func_graph
);
MS_EXCEPTION_IF_NULL
(
target_func_graph
);
MS_EXCEPTION_IF_NULL
(
target_func_graph
);
MS_EXCEPTION_IF_NULL
(
manager_
);
MS_EXCEPTION_IF_NULL
(
manager_
);
const
AnfNodeSet
&
nodes
=
manager_
->
nodes
()[
func_graph
]
;
const
AnfNodeSet
&
nodes
=
func_graph
->
nodes
()
;
for
(
auto
&
node
:
nodes
)
{
for
(
auto
&
node
:
nodes
)
{
CloneNode
(
node
,
target_func_graph
);
CloneNode
(
node
,
target_func_graph
);
}
}
...
...
mindspore/ccsrc/ir/manager.cc
浏览文件 @
848d1920
此差异已折叠。
点击以展开。
mindspore/ccsrc/ir/manager.h
浏览文件 @
848d1920
...
@@ -140,44 +140,6 @@ class FuncGraphAnalysis {
...
@@ -140,44 +140,6 @@ class FuncGraphAnalysis {
using
FuncGraphToAnfNodeMap
=
OrderedMap
<
FuncGraphPtr
,
AnfNodeSet
>
;
using
FuncGraphToAnfNodeMap
=
OrderedMap
<
FuncGraphPtr
,
AnfNodeSet
>
;
// graphs analysis which compute in write, read needn't recompute
class
DepCollector
:
public
FuncGraphAnalysis
{
public:
explicit
DepCollector
(
const
FuncGraphManager
*
manager
);
~
DepCollector
()
override
=
default
;
void
Reset
()
{
ExtraReset
();
}
void
OnInvalidateCollector
()
{
Reset
();
}
protected:
// inherit from FuncGraphAnalysis
void
OnAddEdge
(
AnfNodePtr
node
,
int
index
,
AnfNodePtr
inp
)
override
;
void
OnDropEdge
(
AnfNodePtr
node
,
int
index
,
AnfNodePtr
inp
)
override
;
// subclass can override;
virtual
void
OnModEdge
(
AnfNodePtr
,
int
,
AnfNodePtr
,
EdgeProcessDirection
)
{}
};
class
NodesCollector
final
:
public
DepCollector
{
public:
explicit
NodesCollector
(
const
FuncGraphManager
*
m
);
~
NodesCollector
()
override
=
default
;
const
FuncGraphToAnfNodeMap
&
nodes_analysis
()
const
{
return
nodes_analysis_
;
}
size_t
size
()
const
override
{
return
nodes_analysis_
.
size
();
}
void
OnAddFuncGraph
(
FuncGraphPtr
fg
)
override
{
nodes_analysis_
[
fg
]
=
AnfNodeSet
();
}
void
OnDropFuncGraph
(
FuncGraphPtr
fg
)
override
{
(
void
)
nodes_analysis_
.
erase
(
fg
);
}
void
OnMoveAllCNode
(
FuncGraphPtr
src
,
FuncGraphPtr
dst
)
override
;
FuncGraphToAnfNodeMap
nodes_analysis_
;
protected:
void
ExtraReset
()
override
{
nodes_analysis_
.
clear
();
}
void
OnAddNode
(
AnfNodePtr
n
)
override
;
void
OnDropNode
(
AnfNodePtr
n
)
override
;
};
struct
CNodeIndexHasher
{
struct
CNodeIndexHasher
{
std
::
size_t
operator
()(
const
CNodeIndexPairPtr
pair
)
const
{
std
::
size_t
operator
()(
const
CNodeIndexPairPtr
pair
)
const
{
MS_EXCEPTION_IF_NULL
(
pair
);
MS_EXCEPTION_IF_NULL
(
pair
);
...
@@ -204,59 +166,21 @@ struct CNodeIndexEqual {
...
@@ -204,59 +166,21 @@ struct CNodeIndexEqual {
}
}
};
};
template
<
typename
ValueT
,
class
CollectorHash
=
std
::
hash
<
ValueT
>,
class
CollectorEqual
=
std
::
equal_to
<
ValueT
>>
// graphs analysis which compute in write, read needn't recompute
class
CounterAnfNodeCollector
:
public
DepCollector
{
class
DepCollector
:
public
FuncGraphAnalysis
{
public:
explicit
CounterAnfNodeCollector
(
const
FuncGraphManager
*
m
)
:
DepCollector
(
m
)
{}
~
CounterAnfNodeCollector
()
override
=
default
;
FuncGraphToAnfNodeCounterMap
<
ValueT
,
CollectorHash
,
CollectorEqual
>
&
count_nodes_map
()
{
return
count_nodes_map_
;
}
size_t
size
()
const
override
{
return
count_nodes_map_
.
size
();
}
void
OnAddFuncGraph
(
FuncGraphPtr
fg
)
final
{
count_nodes_map_
[
fg
]
=
OrderedMap
<
ValueT
,
int
,
CollectorHash
,
CollectorEqual
>
();
}
void
OnDropFuncGraph
(
FuncGraphPtr
fg
)
final
{
(
void
)
count_nodes_map_
.
erase
(
fg
);
}
bool
Inc
(
const
FuncGraphPtr
&
func_graph
,
const
ValueT
&
key
,
int
count
);
bool
Dec
(
const
FuncGraphPtr
&
func_graph
,
const
ValueT
&
key
,
int
count
);
bool
Mod
(
const
FuncGraphPtr
&
func_graph
,
const
ValueT
&
key
,
int
count
);
FuncGraphToAnfNodeCounterMap
<
ValueT
,
CollectorHash
,
CollectorEqual
>
count_nodes_map_
;
protected:
void
ExtraReset
()
override
{
count_nodes_map_
.
clear
();
}
};
class
ValueNodesCollector
final
:
public
CounterAnfNodeCollector
<
AnfNodePtr
>
{
public:
explicit
ValueNodesCollector
(
const
FuncGraphManager
*
m
)
:
CounterAnfNodeCollector
(
m
)
{}
~
ValueNodesCollector
()
override
=
default
;
void
OnMoveAllCNode
(
FuncGraphPtr
src
,
FuncGraphPtr
dst
)
override
;
protected:
void
OnModEdge
(
AnfNodePtr
node
,
int
index
,
AnfNodePtr
inp
,
EdgeProcessDirection
direction
)
override
;
};
// Record the CNode and its input index, who points to the function graph.
class
FuncGraphUsersCNodeIndexCollector
final
:
public
CounterAnfNodeCollector
<
CNodeIndexPairPtr
,
CNodeIndexHasher
,
CNodeIndexEqual
>
{
public:
public:
explicit
FuncGraphUsersCNodeIndexCollector
(
const
FuncGraphManager
*
m
)
:
CounterAnfNodeCollector
(
m
)
{}
explicit
DepCollector
(
const
FuncGraphManager
*
manager
);
~
FuncGraphUsersCNodeIndexCollector
()
override
=
default
;
~
DepCollector
()
override
=
default
;
void
OnMoveAllCNode
(
FuncGraphPtr
src
,
FuncGraphPtr
dst
)
override
;
protected:
void
OnModEdge
(
AnfNodePtr
node
,
int
index
,
AnfNodePtr
inp
,
EdgeProcessDirection
direction
)
override
;
};
class
FVDirectCollector
final
:
public
CounterAnfNodeCollector
<
AnfNodePtr
>
{
void
Reset
()
{
ExtraReset
();
}
public:
void
OnInvalidateCollector
()
{
Reset
();
}
explicit
FVDirectCollector
(
const
FuncGraphManager
*
m
)
:
CounterAnfNodeCollector
(
m
)
{}
~
FVDirectCollector
()
override
=
default
;
void
OnMoveAllCNode
(
FuncGraphPtr
src
,
FuncGraphPtr
dst
)
override
;
protected:
protected:
void
OnModEdge
(
AnfNodePtr
node
,
int
index
,
AnfNodePtr
inp
,
EdgeProcessDirection
direction
)
override
;
// inherit from FuncGraphAnalysis
void
OnAddEdge
(
AnfNodePtr
node
,
int
index
,
AnfNodePtr
inp
)
override
;
void
OnDropEdge
(
AnfNodePtr
node
,
int
index
,
AnfNodePtr
inp
)
override
;
// subclass can override;
virtual
void
OnModEdge
(
AnfNodePtr
,
int
,
AnfNodePtr
,
EdgeProcessDirection
)
{}
};
};
class
CounterFuncGraphCollector
:
public
DepCollector
{
class
CounterFuncGraphCollector
:
public
DepCollector
{
...
@@ -278,50 +202,27 @@ class CounterFuncGraphCollector : public DepCollector {
...
@@ -278,50 +202,27 @@ class CounterFuncGraphCollector : public DepCollector {
void
ExtraReset
()
override
{
count_func_graphs_map_
.
clear
();
}
void
ExtraReset
()
override
{
count_func_graphs_map_
.
clear
();
}
};
};
class
FuncGraphChildDirect
final
:
public
CounterFuncGraphCollector
{
template
<
typename
ValueT
,
class
CollectorHash
=
std
::
hash
<
ValueT
>,
class
CollectorEqual
=
std
::
equal_to
<
ValueT
>>
public:
class
CounterAnfNodeCollector
:
public
DepCollector
{
explicit
FuncGraphChildDirect
(
const
FuncGraphManager
*
m
)
:
CounterFuncGraphCollector
(
m
)
{}
void
OnMoveAllCNode
(
FuncGraphPtr
src
,
FuncGraphPtr
dst
)
override
;
~
FuncGraphChildDirect
()
override
=
default
;
protected:
void
OnModEdge
(
AnfNodePtr
node
,
int
index
,
AnfNodePtr
inp
,
EdgeProcessDirection
direction
)
override
;
};
// graph's all parents, parentsdirect have a map, which key is graph, value is this graph's all direct and proxy
// parents:
// 1.proxy parent: graph g use graph f, key is g, value is ParentProxy(f) because f's parent will be g's parent
// 2.direct parent: if graph g's node a used free_variable node in graph f, g's direct parent is f key is g, value is f
class
FuncGraphParentsDirectCollector
final
:
public
CounterFuncGraphCollector
{
public:
public:
explicit
FuncGraphParentsDirectCollector
(
const
FuncGraphManager
*
m
)
:
CounterFuncGraphCollector
(
m
)
{}
explicit
CounterAnfNodeCollector
(
const
FuncGraphManager
*
m
)
:
DepCollector
(
m
)
{}
~
FuncGraphParentsDirectCollector
()
override
=
default
;
~
CounterAnfNodeCollector
()
override
=
default
;
void
OnMoveAllCNode
(
FuncGraphPtr
src
,
FuncGraphPtr
dst
)
override
;
FuncGraphToAnfNodeCounterMap
<
ValueT
,
CollectorHash
,
CollectorEqual
>
&
count_nodes_map
()
{
return
count_nodes_map_
;
}
protected:
void
OnModEdge
(
AnfNodePtr
node
,
int
index
,
AnfNodePtr
inp
,
EdgeProcessDirection
direction
)
override
;
};
// graph's all used graphs: key is g, value is g used graph
size_t
size
()
const
override
{
return
count_nodes_map_
.
size
();
}
class
FuncGraphsUsedCollector
final
:
public
CounterFuncGraphCollector
{
void
OnAddFuncGraph
(
FuncGraphPtr
fg
)
final
{
public:
count_nodes_map_
[
fg
]
=
OrderedMap
<
ValueT
,
int
,
CollectorHash
,
CollectorEqual
>
();
explicit
FuncGraphsUsedCollector
(
const
FuncGraphManager
*
m
)
:
CounterFuncGraphCollector
(
m
)
{}
}
void
OnMoveAllCNode
(
FuncGraphPtr
src
,
FuncGraphPtr
dst
)
override
;
void
OnDropFuncGraph
(
FuncGraphPtr
fg
)
final
{
(
void
)
count_nodes_map_
.
erase
(
fg
);
}
~
FuncGraphsUsedCollector
()
override
=
default
;
protected:
bool
Inc
(
const
FuncGraphPtr
&
func_graph
,
const
ValueT
&
key
,
int
count
);
void
OnModEdge
(
AnfNodePtr
node
,
int
index
,
AnfNodePtr
inp
,
EdgeProcessDirection
direction
)
override
;
bool
Dec
(
const
FuncGraphPtr
&
func_graph
,
const
ValueT
&
key
,
int
count
)
;
}
;
bool
Mod
(
const
FuncGraphPtr
&
func_graph
,
const
ValueT
&
key
,
int
count
)
;
class
FuncGraphJDirectCollector
final
:
public
CounterFuncGraphCollector
{
FuncGraphToAnfNodeCounterMap
<
ValueT
,
CollectorHash
,
CollectorEqual
>
count_nodes_map_
;
public:
explicit
FuncGraphJDirectCollector
(
const
FuncGraphManager
*
m
)
:
CounterFuncGraphCollector
(
m
)
{}
void
OnMoveAllCNode
(
FuncGraphPtr
src
,
const
FuncGraphPtr
dst
)
override
;
~
FuncGraphJDirectCollector
()
override
=
default
;
protected:
protected:
void
OnModEdge
(
AnfNodePtr
node
,
int
index
,
AnfNodePtr
inp
,
EdgeProcessDirection
direction
)
override
;
void
ExtraReset
()
override
{
count_nodes_map_
.
clear
();
}
};
};
using
FuncGraphToFuncGraphSetMap
=
OrderedMap
<
FuncGraphPtr
,
FuncGraphSet
>
;
using
FuncGraphToFuncGraphSetMap
=
OrderedMap
<
FuncGraphPtr
,
FuncGraphSet
>
;
...
@@ -367,8 +268,8 @@ class DepComputer : public FuncGraphAnalysis {
...
@@ -367,8 +268,8 @@ class DepComputer : public FuncGraphAnalysis {
// graph g's all direct or proxy parents
// graph g's all direct or proxy parents
class
FuncGraphParentsTotalComputer
final
:
public
DepComputer
{
class
FuncGraphParentsTotalComputer
final
:
public
DepComputer
{
public:
public:
explicit
FuncGraphParentsTotalComputer
(
const
FuncGraphManager
*
m
)
:
DepComputer
(
m
)
,
all_parents_direct_
(
nullptr
)
{}
explicit
FuncGraphParentsTotalComputer
(
const
FuncGraphManager
*
m
)
:
DepComputer
(
m
)
{}
~
FuncGraphParentsTotalComputer
()
override
{
all_parents_direct_
=
nullptr
;
}
~
FuncGraphParentsTotalComputer
()
override
=
default
;
FuncGraphToFuncGraphSetMap
&
func_graph_parents_total_analysis
()
{
return
func_graph_parents_total_analysis_
;
}
FuncGraphToFuncGraphSetMap
&
func_graph_parents_total_analysis
()
{
return
func_graph_parents_total_analysis_
;
}
...
@@ -382,10 +283,7 @@ class FuncGraphParentsTotalComputer final : public DepComputer {
...
@@ -382,10 +283,7 @@ class FuncGraphParentsTotalComputer final : public DepComputer {
void
RealRecompute
(
FuncGraphPtr
fg
)
override
;
void
RealRecompute
(
FuncGraphPtr
fg
)
override
;
private:
private:
FuncGraphSetPtr
SeekParents
(
const
FuncGraphPtr
&
fg
,
const
FuncGraphSetPtr
&
path
=
std
::
make_shared
<
FuncGraphSet
>
());
FuncGraphSetPtr
SeekParents
(
const
FuncGraphPtr
&
fg
,
size_t
seen_num
);
// when SeekParents calls itself recursively, it can access these variables by class member
// other than pass by formal parameters, it can save 1 parameter for SeekParents().
FuncGraphToFuncGraphCounterMap
*
all_parents_direct_
;
};
};
using
FuncGraphToFuncGraphMap
=
OrderedMap
<
FuncGraphPtr
,
FuncGraphPtr
>
;
using
FuncGraphToFuncGraphMap
=
OrderedMap
<
FuncGraphPtr
,
FuncGraphPtr
>
;
...
@@ -525,7 +423,7 @@ class FuncGraphJTotalComputer final : public DepComputer {
...
@@ -525,7 +423,7 @@ class FuncGraphJTotalComputer final : public DepComputer {
void
ExtraReset
()
override
{
j_total_analysis_
.
clear
();
}
void
ExtraReset
()
override
{
j_total_analysis_
.
clear
();
}
void
RealRecompute
(
FuncGraphPtr
fg
)
override
;
void
RealRecompute
(
FuncGraphPtr
fg
)
override
;
bool
SeekJ
(
const
FuncGraphPtr
&
fg
,
const
FuncGraphSetPtr
&
path
);
bool
SeekJ
(
const
FuncGraphPtr
&
fg
,
size_t
seen_num
);
};
};
class
FuncGraphManager
:
public
std
::
enable_shared_from_this
<
FuncGraphManager
>
{
class
FuncGraphManager
:
public
std
::
enable_shared_from_this
<
FuncGraphManager
>
{
...
@@ -562,30 +460,6 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
...
@@ -562,30 +460,6 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
NodeUsersMap
&
node_users
()
{
return
node_users_
;
}
NodeUsersMap
&
node_users
()
{
return
node_users_
;
}
FuncGraphToAnfNodeMap
&
nodes
()
const
{
return
nodes_
->
nodes_analysis_
;
}
FuncGraphToAnfNodeCounterMap
<
AnfNodePtr
>
&
valuenodes
()
const
{
return
valuenodes_
->
count_nodes_map_
;
}
FuncGraphToAnfNodeCounterMap
<
AnfNodePtr
>
&
free_variables_direct
()
const
{
return
free_variables_direct_
->
count_nodes_map_
;
}
FuncGraphToAnfNodeCounterMap
<
CNodeIndexPairPtr
,
CNodeIndexHasher
,
CNodeIndexEqual
>
&
func_graph_cnodes_index
()
const
{
return
func_graph_cnodes_index_
->
count_nodes_map_
;
}
FuncGraphToFuncGraphCounterMap
&
func_graphs_used
()
const
{
return
func_graphs_used_
->
count_func_graphs_map_
;
}
FuncGraphToFuncGraphCounterMap
&
func_graph_child_direct
()
const
{
return
func_graph_child_direct_
->
count_func_graphs_map_
;
}
FuncGraphToFuncGraphCounterMap
&
func_graph_parents_direct
()
const
{
return
func_graph_parents_direct_
->
count_func_graphs_map_
;
}
FuncGraphToFuncGraphCounterMap
&
func_graph_j_direct
()
const
{
return
func_graph_j_direct_
->
count_func_graphs_map_
;
}
FVTotalMap
&
free_variables_total
()
const
;
FVTotalMap
&
free_variables_total
()
const
;
FuncGraphSet
&
func_graph_parents_total
(
const
FuncGraphPtr
&
fg
)
const
;
FuncGraphSet
&
func_graph_parents_total
(
const
FuncGraphPtr
&
fg
)
const
;
...
@@ -610,14 +484,6 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
...
@@ -610,14 +484,6 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
// Static Analysis
// Static Analysis
NodeUsersMap
node_users_
;
NodeUsersMap
node_users_
;
AnfNodeSet
all_nodes_
;
// managed nodes
AnfNodeSet
all_nodes_
;
// managed nodes
std
::
shared_ptr
<
NodesCollector
>
nodes_
;
std
::
shared_ptr
<
ValueNodesCollector
>
valuenodes_
;
std
::
shared_ptr
<
FVDirectCollector
>
free_variables_direct_
;
std
::
shared_ptr
<
FuncGraphUsersCNodeIndexCollector
>
func_graph_cnodes_index_
;
std
::
shared_ptr
<
FuncGraphsUsedCollector
>
func_graphs_used_
;
std
::
shared_ptr
<
FuncGraphChildDirect
>
func_graph_child_direct_
;
std
::
shared_ptr
<
FuncGraphParentsDirectCollector
>
func_graph_parents_direct_
;
std
::
shared_ptr
<
FuncGraphJDirectCollector
>
func_graph_j_direct_
;
// Dynamic Analysis
// Dynamic Analysis
std
::
shared_ptr
<
ParentComputer
>
func_graph_parent_
;
std
::
shared_ptr
<
ParentComputer
>
func_graph_parent_
;
...
@@ -630,6 +496,9 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
...
@@ -630,6 +496,9 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
FuncGraphSetPtr
MaybeDropNodes
(
const
std
::
vector
<
AnfNodePtr
>
&
nodes
);
FuncGraphSetPtr
MaybeDropNodes
(
const
std
::
vector
<
AnfNodePtr
>
&
nodes
);
void
ParseChanges
(
const
std
::
vector
<
Change
>
&
changes
,
EdgeTupleCounter
*
add_edges
,
EdgeTupleCounter
*
rm_edges
,
void
ParseChanges
(
const
std
::
vector
<
Change
>
&
changes
,
EdgeTupleCounter
*
add_edges
,
EdgeTupleCounter
*
rm_edges
,
Counter
<
AnfNodePtr
>
*
adds
,
Counter
<
AnfNodePtr
>
*
rms
);
Counter
<
AnfNodePtr
>
*
adds
,
Counter
<
AnfNodePtr
>
*
rms
);
void
AddEdge
(
AnfNodePtr
node
,
int
index
,
AnfNodePtr
input
);
void
DropEdge
(
AnfNodePtr
node
,
int
index
,
AnfNodePtr
input
);
void
MoveAllNodes
(
FuncGraphPtr
source
,
FuncGraphPtr
target
);
FuncGraphSet
roots_
;
// managed roots
FuncGraphSet
roots_
;
// managed roots
FuncGraphSet
func_graphs_
;
// managed func graphs
FuncGraphSet
func_graphs_
;
// managed func graphs
...
...
mindspore/ccsrc/optimizer/ad/dfunctor.cc
浏览文件 @
848d1920
...
@@ -492,7 +492,7 @@ void DFunctor::MapParamObject() {
...
@@ -492,7 +492,7 @@ void DFunctor::MapParamObject() {
void
DFunctor
::
MapValueObject
()
{
void
DFunctor
::
MapValueObject
()
{
// Map ValueNode.
// Map ValueNode.
auto
manager
=
resources_
->
manager
();
auto
manager
=
resources_
->
manager
();
auto
&
value_nodes
=
manager
->
valuenodes
()[
primal_graph_
]
;
auto
&
value_nodes
=
primal_graph_
->
value_nodes
()
;
for
(
const
auto
&
value_pair
:
value_nodes
)
{
for
(
const
auto
&
value_pair
:
value_nodes
)
{
auto
node
=
value_pair
.
first
;
auto
node
=
value_pair
.
first
;
auto
parent_adjoint
=
FindAdjoint
(
node
);
auto
parent_adjoint
=
FindAdjoint
(
node
);
...
...
mindspore/ccsrc/optimizer/irpass/branch_culling.cc
浏览文件 @
848d1920
...
@@ -119,7 +119,7 @@ FuncGraphPtr TransformGraphCondBranchNodes(
...
@@ -119,7 +119,7 @@ FuncGraphPtr TransformGraphCondBranchNodes(
std
::
unordered_map
<
AnfNodePtr
,
AnfNodePtr
>
repl_node
;
std
::
unordered_map
<
AnfNodePtr
,
AnfNodePtr
>
repl_node
;
// record the node input to be replaced
// record the node input to be replaced
NodeInputReplMap
repl_node_inputs
;
NodeInputReplMap
repl_node_inputs
;
const
AnfNodeSet
&
nodes
=
manager
->
nodes
()[
graph
]
;
const
AnfNodeSet
&
nodes
=
graph
->
nodes
()
;
for
(
auto
&
node
:
nodes
)
{
for
(
auto
&
node
:
nodes
)
{
MS_EXCEPTION_IF_NULL
(
node
);
MS_EXCEPTION_IF_NULL
(
node
);
if
(
!
node
->
isa
<
CNode
>
())
{
if
(
!
node
->
isa
<
CNode
>
())
{
...
@@ -436,7 +436,7 @@ FuncGraphPtr TransformGraphDependNode(
...
@@ -436,7 +436,7 @@ FuncGraphPtr TransformGraphDependNode(
ResetSharedOp
();
ResetSharedOp
();
std
::
shared_ptr
<
std
::
unordered_map
<
AnfNodePtr
,
AnfNodePtr
>>
repl_node
=
std
::
shared_ptr
<
std
::
unordered_map
<
AnfNodePtr
,
AnfNodePtr
>>
repl_node
=
std
::
make_shared
<
std
::
unordered_map
<
AnfNodePtr
,
AnfNodePtr
>>
();
// record the node to be replaced
std
::
make_shared
<
std
::
unordered_map
<
AnfNodePtr
,
AnfNodePtr
>>
();
// record the node to be replaced
const
AnfNodeSet
&
nodes
=
manager
->
nodes
()[
graph
]
;
const
AnfNodeSet
&
nodes
=
graph
->
nodes
()
;
for
(
auto
&
node
:
nodes
)
{
for
(
auto
&
node
:
nodes
)
{
MS_EXCEPTION_IF_NULL
(
node
);
MS_EXCEPTION_IF_NULL
(
node
);
if
(
!
node
->
isa
<
CNode
>
())
{
if
(
!
node
->
isa
<
CNode
>
())
{
...
...
mindspore/ccsrc/pipeline/action.cc
浏览文件 @
848d1920
...
@@ -391,7 +391,7 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) {
...
@@ -391,7 +391,7 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) {
FuncGraphPtr
func_graph
=
res
->
func_graph
();
FuncGraphPtr
func_graph
=
res
->
func_graph
();
auto
manager
=
res
->
manager
();
auto
manager
=
res
->
manager
();
// Remove duplicated value nodes, due to replace operation, can't use reference.
// Remove duplicated value nodes, due to replace operation, can't use reference.
auto
value_nodes
=
manager
->
valuenodes
()[
func_graph
]
;
auto
value_nodes
=
func_graph
->
value_nodes
()
;
HashCache
hash_cache
;
HashCache
hash_cache
;
HashValue
hashes
;
HashValue
hashes
;
for
(
const
auto
&
value_pair
:
value_nodes
)
{
for
(
const
auto
&
value_pair
:
value_nodes
)
{
...
...
mindspore/ccsrc/vm/transform.cc
浏览文件 @
848d1920
...
@@ -488,12 +488,12 @@ void CompileGraph::AddExternal(const LinConvertResult &result) {
...
@@ -488,12 +488,12 @@ void CompileGraph::AddExternal(const LinConvertResult &result) {
void
TraverseGraphMap
(
void
TraverseGraphMap
(
const
FuncGraphManagerPtr
&
manager_ptr
,
FuncGraphTransaction
*
const
tr
,
const
FuncGraphManagerPtr
&
manager_ptr
,
FuncGraphTransaction
*
const
tr
,
const
FuncGraph
ToAnfNodeCounterMap
<
AnfNodePtr
>
&
ct
s
,
const
FuncGraph
Set
&
fg
s
,
const
std
::
function
<
std
::
shared_ptr
<
FuncGraph
>
(
const
PrimitivePtr
,
const
AbstractFunctionPtr
)
>
&
get_prim_graph
)
{
const
std
::
function
<
std
::
shared_ptr
<
FuncGraph
>
(
const
PrimitivePtr
,
const
AbstractFunctionPtr
)
>
&
get_prim_graph
)
{
MS_EXCEPTION_IF_NULL
(
manager_ptr
);
MS_EXCEPTION_IF_NULL
(
manager_ptr
);
MS_EXCEPTION_IF_NULL
(
tr
);
MS_EXCEPTION_IF_NULL
(
tr
);
for
(
const
auto
&
ct_graphs
:
ct
s
)
{
for
(
const
auto
&
fg
:
fg
s
)
{
for
(
const
auto
&
ct_any
:
ct_graphs
.
second
)
{
for
(
const
auto
&
ct_any
:
fg
->
value_nodes
()
)
{
AnfNodePtr
const_primitive_node
=
ct_any
.
first
;
AnfNodePtr
const_primitive_node
=
ct_any
.
first
;
if
(
const_primitive_node
!=
nullptr
&&
IsValueNode
<
Primitive
>
(
const_primitive_node
))
{
if
(
const_primitive_node
!=
nullptr
&&
IsValueNode
<
Primitive
>
(
const_primitive_node
))
{
auto
users
=
manager_ptr
->
node_users
()[
const_primitive_node
];
auto
users
=
manager_ptr
->
node_users
()[
const_primitive_node
];
...
@@ -553,8 +553,8 @@ FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph) {
...
@@ -553,8 +553,8 @@ FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph) {
};
};
FuncGraphTransaction
tr
=
manager_ptr
->
Transact
();
FuncGraphTransaction
tr
=
manager_ptr
->
Transact
();
auto
&
cts
=
manager_ptr
->
valuenode
s
();
auto
&
fgs
=
manager_ptr
->
func_graph
s
();
TraverseGraphMap
(
manager_ptr
,
&
tr
,
ct
s
,
get_prim_graph
);
TraverseGraphMap
(
manager_ptr
,
&
tr
,
fg
s
,
get_prim_graph
);
return
graph
;
return
graph
;
}
}
...
...
tests/ut/cpp/ir/manager_test.cc
浏览文件 @
848d1920
...
@@ -132,18 +132,6 @@ class NestingSpecs {
...
@@ -132,18 +132,6 @@ class NestingSpecs {
CheckAnfNodeCounter
(
counter_p
);
CheckAnfNodeCounter
(
counter_p
);
return
;
return
;
}
}
auto
counter_pair
=
dynamic_pointer_cast
<
CounterAnfNodeCollector
<
CNodeIndexPairPtr
>>
(
results
);
if
(
counter_pair
!=
nullptr
)
{
CheckCNodeIndexPairCounter
(
counter_pair
);
return
;
}
auto
nodes
=
dynamic_pointer_cast
<
NodesCollector
>
(
results
);
if
(
nodes
!=
nullptr
)
{
CheckNodes
(
nodes
);
return
;
}
}
}
private:
private:
...
@@ -205,33 +193,7 @@ class NestingSpecs {
...
@@ -205,33 +193,7 @@ class NestingSpecs {
ASSERT_EQ
(
clean_results
,
expected_
);
ASSERT_EQ
(
clean_results
,
expected_
);
}
}
void
CheckNodes
(
std
::
shared_ptr
<
NodesCollector
>
results
)
{
std
::
map
<
std
::
string
,
std
::
set
<
std
::
string
>>
clean_results
;
for
(
auto
&
iter
:
results
->
nodes_analysis
())
{
auto
key
=
iter
.
first
;
auto
value
=
iter
.
second
;
if
(
key
==
nullptr
)
{
continue
;
}
std
::
string
k
=
Name
(
key
);
std
::
set
<
std
::
string
>
v
;
for
(
auto
&
node
:
value
)
{
if
(
!
node
->
isa
<
CNode
>
()
&&
!
Name
(
node
).
empty
())
{
v
.
insert
(
Name
(
node
));
}
}
if
(
!
v
.
empty
())
{
clean_results
[
k
]
=
v
;
}
}
ASSERT_EQ
(
clean_results
,
expected_
);
}
// Add CheckNesting function
// Add CheckNesting function
void
CheckAnfNodeCounter
(
std
::
shared_ptr
<
CounterAnfNodeCollector
<
AnfNodePtr
>>
results
)
{
void
CheckAnfNodeCounter
(
std
::
shared_ptr
<
CounterAnfNodeCollector
<
AnfNodePtr
>>
results
)
{
std
::
map
<
std
::
string
,
std
::
set
<
std
::
string
>>
clean_results
;
std
::
map
<
std
::
string
,
std
::
set
<
std
::
string
>>
clean_results
;
for
(
auto
&
iter
:
results
->
count_nodes_map
())
{
for
(
auto
&
iter
:
results
->
count_nodes_map
())
{
...
@@ -258,32 +220,6 @@ class NestingSpecs {
...
@@ -258,32 +220,6 @@ class NestingSpecs {
ASSERT_EQ
(
clean_results
,
expected_
);
ASSERT_EQ
(
clean_results
,
expected_
);
}
}
void
CheckCNodeIndexPairCounter
(
std
::
shared_ptr
<
CounterAnfNodeCollector
<
CNodeIndexPairPtr
>>
results
)
{
std
::
map
<
std
::
string
,
std
::
set
<
std
::
string
>>
clean_results
;
for
(
auto
&
iter
:
results
->
count_nodes_map
())
{
auto
key
=
iter
.
first
;
auto
value
=
iter
.
second
;
if
(
key
==
nullptr
)
{
continue
;
}
std
::
string
k
=
Name
(
key
);
std
::
set
<
std
::
string
>
v
;
for
(
auto
&
node
:
value
)
{
auto
fg
=
node
.
first
->
first
;
if
(
!
Name
(
fg
).
empty
())
{
v
.
insert
(
Name
(
fg
));
}
}
if
(
!
v
.
empty
())
{
clean_results
[
k
]
=
v
;
}
}
ASSERT_EQ
(
clean_results
,
expected_
);
}
void
CheckGraphCounter
(
std
::
shared_ptr
<
CounterFuncGraphCollector
>
results
)
{
void
CheckGraphCounter
(
std
::
shared_ptr
<
CounterFuncGraphCollector
>
results
)
{
std
::
map
<
std
::
string
,
std
::
set
<
std
::
string
>>
clean_results
;
std
::
map
<
std
::
string
,
std
::
set
<
std
::
string
>>
clean_results
;
for
(
auto
&
iter
:
results
->
count_func_graphs_map
())
{
for
(
auto
&
iter
:
results
->
count_func_graphs_map
())
{
...
@@ -471,17 +407,10 @@ std::vector<FuncGraphPtr> MakeNestedGraph2() {
...
@@ -471,17 +407,10 @@ std::vector<FuncGraphPtr> MakeNestedGraph2() {
}
}
// Add TestManager::CheckManager function to checkout the result
// Add TestManager::CheckManager function to checkout the result
void
TestManager
::
CheckAnalysisSize
(
std
::
shared_ptr
<
FuncGraphManager
>
mng
)
{
void
TestManager
::
CheckAnalysisSize
(
std
::
shared_ptr
<
FuncGraphManager
>
mng
)
{
auto
size
=
mng
->
func_graphs
().
size
();
auto
size
=
mng
->
func_graphs
().
size
();
ASSERT_EQ
(
size
+
1
,
mng
->
nodes
().
size
());
ASSERT_EQ
(
size
,
mng
->
free_variables_total
().
size
());
ASSERT_EQ
(
size
,
mng
->
free_variables_total
().
size
());
ASSERT_EQ
(
size
,
mng
->
valuenodes
().
size
());
ASSERT_EQ
(
size
,
mng
->
free_variables_direct
().
size
());
ASSERT_EQ
(
size
,
mng
->
func_graph_cnodes_index
().
size
());
ASSERT_EQ
(
size
,
mng
->
func_graph_parents_direct
().
size
());
ASSERT_EQ
(
size
,
mng
->
func_graphs_used
().
size
());
}
}
TEST_F
(
TestManager
,
test_scalar_add_manual
)
{
TEST_F
(
TestManager
,
test_scalar_add_manual
)
{
...
@@ -525,31 +454,26 @@ TEST_F(TestManager, test_nested_manual) {
...
@@ -525,31 +454,26 @@ TEST_F(TestManager, test_nested_manual) {
ASSERT_EQ
(
1
,
mng
->
roots
().
size
());
ASSERT_EQ
(
1
,
mng
->
roots
().
size
());
CheckAnalysisSize
(
mng
);
CheckAnalysisSize
(
mng
);
auto
nodes
=
mng
->
nodes
();
ASSERT_EQ
(
2
,
f
->
nodes
().
size
());
ASSERT_EQ
(
3
,
nodes
[
nullptr
].
size
());
ASSERT_EQ
(
1
,
g
->
nodes
().
size
());
ASSERT_EQ
(
2
,
nodes
[
f
].
size
());
ASSERT_EQ
(
1
,
nodes
[
g
].
size
());
auto
users
=
mng
->
node_users
();
auto
users
=
mng
->
node_users
();
for
(
auto
&
iter
:
users
)
{
for
(
auto
&
iter
:
users
)
{
ASSERT_EQ
(
1
,
iter
.
second
.
size
());
ASSERT_EQ
(
1
,
iter
.
second
.
size
());
}
}
auto
graphs_used
=
mng
->
func_graphs_used
();
ASSERT_EQ
(
1
,
f
->
func_graphs_used
().
size
());
ASSERT_EQ
(
1
,
graphs_used
[
f
].
size
());
ASSERT_EQ
(
0
,
g
->
func_graphs_used
().
size
());
ASSERT_EQ
(
0
,
graphs_used
[
g
].
size
());
auto
fv_direct
=
mng
->
free_variables_direct
();
ASSERT_EQ
(
0
,
f
->
free_variables
().
size
());
ASSERT_EQ
(
0
,
fv_direct
[
f
].
size
());
ASSERT_EQ
(
1
,
g
->
free_variables
().
size
());
ASSERT_EQ
(
1
,
fv_direct
[
g
].
size
());
auto
fv_total
=
mng
->
free_variables_total
();
auto
fv_total
=
mng
->
free_variables_total
();
ASSERT_EQ
(
0
,
fv_total
[
f
].
size
());
ASSERT_EQ
(
0
,
fv_total
[
f
].
size
());
ASSERT_EQ
(
1
,
fv_total
[
g
].
size
());
ASSERT_EQ
(
1
,
fv_total
[
g
].
size
());
auto
cnodes
=
mng
->
func_graph_cnodes_index
();
ASSERT_EQ
(
0
,
f
->
func_graph_cnodes_index
().
size
());
ASSERT_EQ
(
0
,
cnodes
[
f
].
size
());
ASSERT_EQ
(
1
,
g
->
func_graph_cnodes_index
().
size
());
ASSERT_EQ
(
1
,
cnodes
[
g
].
size
());
}
}
TEST_F
(
TestManager
,
test_deep_nested2_manual
)
{
TEST_F
(
TestManager
,
test_deep_nested2_manual
)
{
...
@@ -567,7 +491,7 @@ TEST_F(TestManager, test_deep_nested2_manual) {
...
@@ -567,7 +491,7 @@ TEST_F(TestManager, test_deep_nested2_manual) {
ASSERT_EQ
(
3
,
mng
->
func_graphs
().
size
());
ASSERT_EQ
(
3
,
mng
->
func_graphs
().
size
());
ASSERT_EQ
(
1
,
mng
->
roots
().
size
());
ASSERT_EQ
(
1
,
mng
->
roots
().
size
());
ASSERT_EQ
(
4
,
mng
->
nodes
().
size
());
ASSERT_EQ
(
4
,
gfn
->
nodes
().
size
());
ASSERT_EQ
(
20
,
mng
->
all_nodes
().
size
());
ASSERT_EQ
(
20
,
mng
->
all_nodes
().
size
());
ASSERT_EQ
(
25
,
mng
->
node_users
().
size
());
ASSERT_EQ
(
25
,
mng
->
node_users
().
size
());
CheckAnalysisSize
(
mng
);
CheckAnalysisSize
(
mng
);
...
@@ -631,7 +555,6 @@ TEST_F(TestManager, test_deep_nested_manual) {
...
@@ -631,7 +555,6 @@ TEST_F(TestManager, test_deep_nested_manual) {
ASSERT_EQ
(
3
,
mng
->
func_graphs
().
size
());
ASSERT_EQ
(
3
,
mng
->
func_graphs
().
size
());
ASSERT_EQ
(
1
,
mng
->
roots
().
size
());
ASSERT_EQ
(
1
,
mng
->
roots
().
size
());
ASSERT_EQ
(
4
,
mng
->
nodes
().
size
());
ASSERT_EQ
(
20
,
mng
->
all_nodes
().
size
());
ASSERT_EQ
(
20
,
mng
->
all_nodes
().
size
());
CheckAnalysisSize
(
mng
);
CheckAnalysisSize
(
mng
);
}
}
...
@@ -716,12 +639,12 @@ TEST_F(TestManager, test_drop_root) {
...
@@ -716,12 +639,12 @@ TEST_F(TestManager, test_drop_root) {
FuncGraphPtr
fg
=
getPyFun
(
"ir_get_fn"
);
FuncGraphPtr
fg
=
getPyFun
(
"ir_get_fn"
);
auto
mng
=
Manage
(
fg
);
auto
mng
=
Manage
(
fg
);
const
FuncGraphToAnfNodeMap
&
nodes
=
mng
->
node
s
();
const
auto
&
fgs
=
mng
->
func_graph
s
();
ASSERT_TRUE
(
nodes
.
find
(
fg
)
!=
nodes
.
end
(
));
ASSERT_TRUE
(
fgs
.
contains
(
fg
));
FuncGraphSet
s
;
FuncGraphSet
s
;
s
.
add
(
fg
);
s
.
add
(
fg
);
mng
->
MaybeDropFuncGraphs
(
s
);
mng
->
MaybeDropFuncGraphs
(
s
);
ASSERT_TRUE
(
nodes
.
find
(
fg
)
!=
nodes
.
end
(
));
ASSERT_TRUE
(
fgs
.
contains
(
fg
));
}
}
TEST_F
(
TestManager
,
test_keep_roots
)
{
TEST_F
(
TestManager
,
test_keep_roots
)
{
...
...
tests/ut/cpp/optimizer/cconv_test.cc
浏览文件 @
848d1920
...
@@ -26,15 +26,14 @@
...
@@ -26,15 +26,14 @@
namespace
mindspore
{
namespace
mindspore
{
void
CheckNoFreeVariables
(
FuncGraphPtr
root
)
{
void
CheckNoFreeVariables
(
FuncGraphPtr
root
)
{
auto
mng
=
Manage
(
root
);
auto
mng
=
Manage
(
root
);
for
(
auto
&
iter
:
mng
->
nodes
())
{
for
(
auto
&
iter
:
mng
->
func_graphs
())
{
auto
g
=
iter
.
first
;
auto
g
=
iter
;
auto
nodes
=
iter
.
second
;
if
(
g
==
nullptr
)
{
if
(
g
==
nullptr
)
{
continue
;
continue
;
}
}
ASSERT_TRUE
(
g
->
parent
()
==
nullptr
);
ASSERT_TRUE
(
g
->
parent
()
==
nullptr
);
auto
nodes
=
g
->
nodes
();
for
(
auto
&
node
:
nodes
)
{
for
(
auto
&
node
:
nodes
)
{
ASSERT_EQ
(
node
->
func_graph
(),
g
);
ASSERT_EQ
(
node
->
func_graph
(),
g
);
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录