Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
dbb86cb1
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
dbb86cb1
编写于
5月 24, 2020
作者:
Z
Zhang Qinghua
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Adjust some routines of FG and FGM, about the nodes info. IF.
上级
737bfc95
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
113 addition
and
101 deletion
+113
-101
mindspore/ccsrc/ir/func_graph.cc
mindspore/ccsrc/ir/func_graph.cc
+55
-37
mindspore/ccsrc/ir/func_graph.h
mindspore/ccsrc/ir/func_graph.h
+16
-16
mindspore/ccsrc/ir/func_graph_cloner.cc
mindspore/ccsrc/ir/func_graph_cloner.cc
+3
-3
mindspore/ccsrc/ir/manager.cc
mindspore/ccsrc/ir/manager.cc
+36
-42
mindspore/ccsrc/parallel/step_parallel.cc
mindspore/ccsrc/parallel/step_parallel.cc
+1
-1
tests/ut/cpp/ir/manager_test.cc
tests/ut/cpp/ir/manager_test.cc
+2
-2
未找到文件。
mindspore/ccsrc/ir/func_graph.cc
浏览文件 @
dbb86cb1
...
...
@@ -198,7 +198,7 @@ GraphDebugInfoPtr FuncGraph::debug_info() {
const
AnfNodeSet
&
FuncGraph
::
nodes
()
{
return
nodes_
;
}
void
FuncGraph
::
CopyNodes
(
const
AnfNodeSet
&
other_nodes
)
{
nodes_
=
other_nodes
;
}
void
FuncGraph
::
CopyNodes
(
const
FuncGraphPtr
&
source
)
{
nodes_
=
source
->
nodes
()
;
}
void
FuncGraph
::
ClearNodes
()
{
nodes_
.
clear
();
}
...
...
@@ -215,7 +215,12 @@ void FuncGraph::DropNode(AnfNodePtr node) {
const
AnfNodeCounterMap
&
FuncGraph
::
value_nodes
()
{
return
value_nodes_
;
}
void
FuncGraph
::
CopyValueNodes
(
const
AnfNodeCounterMap
&
other_value_nodes
)
{
value_nodes_
=
other_value_nodes
;
}
void
FuncGraph
::
CopyValueNodes
(
const
FuncGraphPtr
&
source
)
{
auto
&
others
=
source
->
value_nodes
();
for
(
auto
it
=
others
.
begin
();
it
!=
others
.
end
();
it
++
)
{
AddValueNode
(
it
->
first
,
it
->
second
);
}
}
void
FuncGraph
::
ClearValueNodes
()
{
value_nodes_
.
clear
();
}
...
...
@@ -243,9 +248,9 @@ void FuncGraph::DropValueNode(AnfNodePtr node) {
const
AnfNodeCounterMap
&
FuncGraph
::
free_variables
()
{
return
free_variables_
;
}
void
FuncGraph
::
CopyFreeVariables
(
const
AnfNodeCounterMap
&
others
)
{
auto
it
=
others
.
begin
();
for
(;
it
!=
others
.
end
();
it
++
)
{
void
FuncGraph
::
CopyFreeVariables
(
const
FuncGraphPtr
&
source
)
{
auto
&
others
=
source
->
free_variables
();
for
(
auto
it
=
others
.
begin
()
;
it
!=
others
.
end
();
it
++
)
{
if
(
it
->
first
->
func_graph
().
get
()
!=
this
)
{
(
void
)
AddFreeVariable
(
it
->
first
,
it
->
second
);
}
...
...
@@ -313,31 +318,37 @@ std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() {
return
func_graphs
;
}
const
AnfNodeCounterMap
&
FuncGraph
::
func_graph_value_nodes
()
{
return
func_graph_value_nodes
_
;
}
const
FuncGraphCounterMap
&
FuncGraph
::
func_graphs_used
()
{
return
func_graphs_used
_
;
}
void
FuncGraph
::
CopyFuncGraphValueNodes
(
const
AnfNodeCounterMap
&
others
)
{
func_graph_value_nodes_
=
others
;
}
void
FuncGraph
::
CopyFuncGraphsUsed
(
const
FuncGraphPtr
&
source
)
{
auto
&
others
=
source
->
func_graphs_used
();
for
(
auto
it
=
others
.
begin
();
it
!=
others
.
end
();
it
++
)
{
(
void
)
AddFuncGraphUsed
(
it
->
first
,
it
->
second
);
}
func_graphs_used_
.
erase
(
source
);
}
void
FuncGraph
::
ClearFuncGraph
ValueNodes
()
{
func_graph_value_nodes
_
.
clear
();
}
void
FuncGraph
::
ClearFuncGraph
sUsed
()
{
func_graphs_used
_
.
clear
();
}
bool
FuncGraph
::
AddFuncGraph
ValueNode
(
AnfNodePtr
node
,
int
count
)
{
if
(
func_graph
_value_nodes_
.
count
(
node
)
==
0
)
{
func_graph
_value_nodes_
[
node
]
=
count
;
bool
FuncGraph
::
AddFuncGraph
Used
(
FuncGraphPtr
fg
,
int
count
)
{
if
(
func_graph
s_used_
.
count
(
fg
)
==
0
)
{
func_graph
s_used_
[
fg
]
=
count
;
return
true
;
}
else
{
func_graph
_value_nodes_
[
node
]
+=
count
;
func_graph
s_used_
[
fg
]
+=
count
;
return
false
;
}
}
bool
FuncGraph
::
DropFuncGraph
ValueNode
(
AnfNodePtr
node
)
{
if
(
func_graph
_value_nodes_
.
count
(
node
)
!=
0
)
{
if
(
func_graph
_value_nodes_
[
node
]
==
1
)
{
(
void
)
func_graph
_value_nodes_
.
erase
(
node
);
bool
FuncGraph
::
DropFuncGraph
Used
(
FuncGraphPtr
fg
)
{
if
(
func_graph
s_used_
.
count
(
fg
)
!=
0
)
{
if
(
func_graph
s_used_
[
fg
]
==
1
)
{
(
void
)
func_graph
s_used_
.
erase
(
fg
);
return
true
;
}
else
{
func_graph
_value_nodes_
[
node
]
--
;
if
(
func_graph
_value_nodes_
[
node
]
<
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"Count of
value node(FuncGraph) '"
<<
node
func_graph
s_used_
[
fg
]
--
;
if
(
func_graph
s_used_
[
fg
]
<
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"Count of
FuncGraph '"
<<
fg
<<
"' dec from 0. NodeInfo: "
<<
trace
::
GetDebugInfo
(
debug_info
());
}
}
...
...
@@ -354,11 +365,13 @@ const FuncGraphSet &FuncGraph::func_graphs_used_total() {
const
CNodeIndexCounterMap
&
FuncGraph
::
func_graph_cnodes_index
()
{
return
func_graph_cnodes_index_
;
}
void
FuncGraph
::
CopyFuncGraphCNodesIndex
(
const
CNodeIndexCounterMap
&
others
)
{
auto
it
=
others
.
begin
();
for
(;
it
!=
others
.
end
();
it
++
)
{
void
FuncGraph
::
CopyFuncGraphCNodesIndex
(
const
FuncGraphPtr
&
source
)
{
auto
&
others
=
source
->
func_graph_cnodes_index
();
for
(
auto
it
=
others
.
begin
()
;
it
!=
others
.
end
();
it
++
)
{
// Ignore the user graph who may own itself.
if
(
it
->
first
->
first
->
func_graph
().
get
()
!=
this
)
{
auto
fg
=
it
->
first
->
first
->
func_graph
();
MS_EXCEPTION_IF_NULL
(
fg
);
if
(
fg
.
get
()
!=
this
)
{
AddFuncGraphCNodeIndex
(
it
->
first
,
it
->
second
);
}
}
...
...
@@ -388,28 +401,33 @@ void FuncGraph::DropFuncGraphCNodeIndex(CNodeIndexPairPtr pair) {
}
}
const
AnfNodeCounterMap
&
FuncGraph
::
j_func_graph_value_nodes
()
{
return
j_func_graph_value_node
s_
;
}
const
FuncGraphCounterMap
&
FuncGraph
::
j_func_graphs
()
{
return
j_func_graph
s_
;
}
void
FuncGraph
::
CopyJFuncGraphValueNodes
(
const
AnfNodeCounterMap
&
others
)
{
j_func_graph_value_nodes_
=
others
;
}
void
FuncGraph
::
CopyJFuncGraphs
(
const
FuncGraphPtr
&
source
)
{
auto
&
others
=
source
->
j_func_graphs
();
for
(
auto
it
=
others
.
begin
();
it
!=
others
.
end
();
it
++
)
{
AddJFuncGraph
(
it
->
first
,
it
->
second
);
}
}
void
FuncGraph
::
ClearJFuncGraph
ValueNodes
()
{
j_func_graph_value_node
s_
.
clear
();
}
void
FuncGraph
::
ClearJFuncGraph
s
()
{
j_func_graph
s_
.
clear
();
}
void
FuncGraph
::
AddJFuncGraph
ValueNode
(
AnfNodePtr
node
,
int
count
)
{
if
(
j_func_graph
_value_nodes_
.
count
(
node
)
==
0
)
{
j_func_graph
_value_nodes_
[
node
]
=
count
;
void
FuncGraph
::
AddJFuncGraph
(
FuncGraphPtr
fg
,
int
count
)
{
if
(
j_func_graph
s_
.
count
(
fg
)
==
0
)
{
j_func_graph
s_
[
fg
]
=
count
;
}
else
{
j_func_graph
_value_nodes_
[
node
]
+=
count
;
j_func_graph
s_
[
fg
]
+=
count
;
}
}
void
FuncGraph
::
DropJFuncGraph
ValueNode
(
AnfNodePtr
node
)
{
if
(
j_func_graph
_value_nodes_
.
count
(
node
)
!=
0
)
{
if
(
j_func_graph
_value_nodes_
[
node
]
==
1
)
{
(
void
)
j_func_graph
_value_nodes_
.
erase
(
node
);
void
FuncGraph
::
DropJFuncGraph
(
FuncGraphPtr
fg
)
{
if
(
j_func_graph
s_
.
count
(
fg
)
!=
0
)
{
if
(
j_func_graph
s_
[
fg
]
==
1
)
{
(
void
)
j_func_graph
s_
.
erase
(
fg
);
}
else
{
j_func_graph
_value_nodes_
[
node
]
--
;
if
(
j_func_graph
_value_nodes_
[
node
]
<
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"Count of
value node(J FuncGraph) '"
<<
node
j_func_graph
s_
[
fg
]
--
;
if
(
j_func_graph
s_
[
fg
]
<
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"Count of
J FuncGraph '"
<<
fg
<<
"' dec from 0. NodeInfo: "
<<
trace
::
GetDebugInfo
(
debug_info
());
}
}
...
...
mindspore/ccsrc/ir/func_graph.h
浏览文件 @
dbb86cb1
...
...
@@ -189,21 +189,21 @@ class FuncGraph : public FuncGraphBase {
// get all nodes belonging to this func graph
const
AnfNodeSet
&
nodes
();
void
CopyNodes
(
const
AnfNodeSet
&
other_nodes
);
void
CopyNodes
(
const
FuncGraphPtr
&
source
);
void
ClearNodes
();
void
AddNode
(
AnfNodePtr
node
);
void
DropNode
(
AnfNodePtr
node
);
// get all value_nodes belonging to this func graph
const
AnfNodeCounterMap
&
value_nodes
();
void
CopyValueNodes
(
const
AnfNodeCounterMap
&
other_value_nodes
);
void
CopyValueNodes
(
const
FuncGraphPtr
&
source
);
void
ClearValueNodes
();
void
AddValueNode
(
AnfNodePtr
node
,
int
count
=
1
);
void
DropValueNode
(
AnfNodePtr
node
);
// get all free vars directly used in this func graph
const
AnfNodeCounterMap
&
free_variables
();
void
CopyFreeVariables
(
const
AnfNodeCounterMap
&
others
);
void
CopyFreeVariables
(
const
FuncGraphPtr
&
source
);
void
ClearFreeVariables
();
bool
AddFreeVariable
(
AnfNodePtr
node
,
int
count
=
1
);
bool
DropFreeVariable
(
AnfNodePtr
node
);
...
...
@@ -218,25 +218,25 @@ class FuncGraph : public FuncGraphBase {
std
::
vector
<
FuncGraphPtr
>
free_variables_func_graphs
();
// get all value nodes of func graph directly used by this func graph
const
AnfNodeCounterMap
&
func_graph_value_nodes
();
void
CopyFuncGraph
ValueNodes
(
const
AnfNodeCounterMap
&
others
);
void
ClearFuncGraph
ValueNodes
();
bool
AddFuncGraph
ValueNode
(
AnfNodePtr
node
,
int
count
=
1
);
bool
DropFuncGraph
ValueNode
(
AnfNodePtr
node
);
const
FuncGraphCounterMap
&
func_graphs_used
();
void
CopyFuncGraph
sUsed
(
const
FuncGraphPtr
&
source
);
void
ClearFuncGraph
sUsed
();
bool
AddFuncGraph
Used
(
FuncGraphPtr
fg
,
int
count
=
1
);
bool
DropFuncGraph
Used
(
FuncGraphPtr
fg
);
// get all value nodes of J func graph directly used by this func graph
const
AnfNodeCounterMap
&
j_func_graph_value_node
s
();
void
CopyJFuncGraph
ValueNodes
(
const
AnfNodeCounterMap
&
others
);
void
ClearJFuncGraph
ValueNode
s
();
void
AddJFuncGraph
ValueNode
(
AnfNodePtr
node
,
int
count
=
1
);
void
DropJFuncGraph
ValueNode
(
AnfNodePtr
node
);
const
FuncGraphCounterMap
&
j_func_graph
s
();
void
CopyJFuncGraph
s
(
const
FuncGraphPtr
&
source
);
void
ClearJFuncGraphs
();
void
AddJFuncGraph
(
FuncGraphPtr
fg
,
int
count
=
1
);
void
DropJFuncGraph
(
FuncGraphPtr
fg
);
// get all func graphs nested used by this func graph
const
FuncGraphSet
&
func_graphs_used_total
();
// get all user value nodes of this func graph, by CNode and its input's index
const
CNodeIndexCounterMap
&
func_graph_cnodes_index
();
void
CopyFuncGraphCNodesIndex
(
const
CNodeIndexCounterMap
&
other_value_nodes
);
void
CopyFuncGraphCNodesIndex
(
const
FuncGraphPtr
&
source
);
void
ClearFuncGraphCNodesIndex
();
void
AddFuncGraphCNodeIndex
(
CNodeIndexPairPtr
node
,
int
count
=
1
);
void
DropFuncGraphCNodeIndex
(
CNodeIndexPairPtr
node
);
...
...
@@ -311,13 +311,13 @@ class FuncGraph : public FuncGraphBase {
AnfNodeCounterMap
value_nodes_
;
// all func graph value nodes of the function
AnfNodeCounterMap
func_graph_value_nodes
_
;
FuncGraphCounterMap
func_graphs_used
_
;
// all free variables of the function
AnfNodeCounterMap
free_variables_
;
// all value nodes calling J in the function
AnfNodeCounterMap
j_func_graph_value_node
s_
;
FuncGraphCounterMap
j_func_graph
s_
;
// all user value nodes of this func graph, recording by CNode and its input's index
CNodeIndexCounterMap
func_graph_cnodes_index_
;
...
...
mindspore/ccsrc/ir/func_graph_cloner.cc
浏览文件 @
dbb86cb1
...
...
@@ -153,9 +153,9 @@ void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) {
if
(
!
clone_all_used_graphs_
)
{
return
;
}
auto
&
used
=
func_graph
->
func_graph
_value_nodes
();
for
(
auto
&
fg
_value_node
:
used
)
{
todo_
.
push_back
({
GetValueNode
<
FuncGraphPtr
>
(
fg_value_node
.
first
)
,
nullptr
,
{}});
auto
&
used
=
func_graph
->
func_graph
s_used
();
for
(
auto
&
fg
:
used
)
{
todo_
.
push_back
({
fg
.
first
,
nullptr
,
{}});
}
}
...
...
mindspore/ccsrc/ir/manager.cc
浏览文件 @
dbb86cb1
...
...
@@ -196,7 +196,6 @@ void FuncGraphManager::AddFuncGraph(FuncGraphPtr func_graph, bool is_root) {
return
;
}
AddIntoManaged
(
func_graph
);
MS_EXCEPTION_IF_NULL
(
signals_
);
std
::
vector
<
AnfNodePtr
>
para
=
func_graph
->
parameters
();
AcquireNodes
(
para
);
std
::
vector
<
AnfNodePtr
>
return_vec
({
func_graph
->
get_return
()});
...
...
@@ -301,7 +300,6 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool
std
::
vector
<
AnfNodePtr
>
return_vec
=
{
func_graph
->
get_return
()};
todo
.
update
(
MaybeDropNodes
(
return_vec
));
}
MS_EXCEPTION_IF_NULL
(
signals_
);
for
(
auto
&
fg
:
dropped
)
{
MS_EXCEPTION_IF_NULL
(
fg
);
all_nodes_
.
difference_update
(
fg
->
parameters
());
...
...
@@ -334,7 +332,6 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E
}
auto
&
users_node
=
node_users_
[
inp
];
users_node
.
add
(
make_pair
(
node
,
index
));
MS_EXCEPTION_IF_NULL
(
signals_
);
AddEdge
(
node
,
index
,
inp
);
}
}
...
...
@@ -384,8 +381,6 @@ void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr> &nodes) {
FuncGraphSetPtr
FuncGraphManager
::
MaybeDropNodes
(
const
std
::
vector
<
AnfNodePtr
>
&
nodes
)
{
AnfNodeSet
nodes_ordered
(
nodes
);
FuncGraphSetPtr
func_graphs_to_check
=
std
::
make_shared
<
FuncGraphSet
>
();
MS_EXCEPTION_IF_NULL
(
signals_
);
while
(
!
nodes_ordered
.
empty
())
{
AnfNodePtr
node
=
nodes_ordered
.
pop
();
MS_EXCEPTION_IF_NULL
(
node
);
...
...
@@ -475,13 +470,13 @@ inline void FuncGraphManager::AddEdge(AnfNodePtr node, int index, AnfNodePtr inp
if
(
input
->
isa
<
ValueNode
>
())
{
fg
->
AddValueNode
(
input
);
if
(
IsValueNode
<
FuncGraph
>
(
input
))
{
if
(
fg
->
AddFuncGraphValueNode
(
input
))
{
signals_
->
InvalidateComputer
();
}
auto
used
=
GetValueNode
<
FuncGraphPtr
>
(
input
);
used
->
AddFuncGraphCNodeIndex
(
std
::
make_shared
<
CNodeIndexPair
>
(
std
::
make_pair
(
node
,
index
)));
if
(
fg
->
AddFuncGraphUsed
(
used
))
{
signals_
->
InvalidateComputer
();
}
if
(
IsPrimitiveCNode
(
node
,
prim
::
kPrimJ
))
{
fg
->
AddJFuncGraph
ValueNode
(
input
);
fg
->
AddJFuncGraph
(
used
);
}
}
}
else
if
(
fg
!=
nullptr
&&
fg
!=
input
->
func_graph
())
{
...
...
@@ -496,13 +491,13 @@ inline void FuncGraphManager::DropEdge(AnfNodePtr node, int index, AnfNodePtr in
if
(
input
->
isa
<
ValueNode
>
())
{
fg
->
DropValueNode
(
input
);
if
(
IsValueNode
<
FuncGraph
>
(
input
))
{
if
(
fg
->
DropFuncGraphValueNode
(
input
))
{
signals_
->
InvalidateComputer
();
}
auto
used
=
GetValueNode
<
FuncGraphPtr
>
(
input
);
used
->
DropFuncGraphCNodeIndex
(
std
::
make_shared
<
CNodeIndexPair
>
(
std
::
make_pair
(
node
,
index
)));
if
(
fg
->
DropFuncGraphUsed
(
used
))
{
signals_
->
InvalidateComputer
();
}
if
(
IsPrimitiveCNode
(
node
,
prim
::
kPrimJ
))
{
fg
->
DropJFuncGraph
ValueNode
(
input
);
fg
->
DropJFuncGraph
(
used
);
}
}
}
else
if
(
fg
!=
nullptr
&&
fg
!=
input
->
func_graph
())
{
...
...
@@ -513,19 +508,19 @@ inline void FuncGraphManager::DropEdge(AnfNodePtr node, int index, AnfNodePtr in
}
inline
void
FuncGraphManager
::
MoveAllNodes
(
FuncGraphPtr
source
,
FuncGraphPtr
target
)
{
target
->
CopyNodes
(
source
->
nodes
()
);
target
->
CopyValueNodes
(
source
->
value_nodes
()
);
target
->
CopyFuncGraphCNodesIndex
(
source
->
func_graph_cnodes_index
()
);
target
->
CopyFreeVariables
(
source
->
free_variables
()
);
target
->
CopyFuncGraph
ValueNodes
(
source
->
func_graph_value_nodes
()
);
target
->
CopyJFuncGraph
ValueNodes
(
source
->
j_func_graph_value_nodes
()
);
target
->
CopyNodes
(
source
);
target
->
CopyValueNodes
(
source
);
target
->
CopyFuncGraphCNodesIndex
(
source
);
target
->
CopyFreeVariables
(
source
);
target
->
CopyFuncGraph
sUsed
(
source
);
target
->
CopyJFuncGraph
s
(
source
);
signals_
->
InvalidateComputer
();
source
->
ClearNodes
();
source
->
ClearValueNodes
();
source
->
ClearFuncGraphCNodesIndex
();
source
->
ClearFreeVariables
();
source
->
ClearFuncGraph
ValueNodes
();
source
->
ClearJFuncGraph
ValueNode
s
();
source
->
ClearFuncGraph
sUsed
();
source
->
ClearJFuncGraphs
();
}
FuncGraphTransaction
FuncGraphManager
::
Transact
()
{
...
...
@@ -768,10 +763,10 @@ FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &f
}
// Search the fv in fg's child func graph.
auto
&
fg
_value_nodes
=
fg
->
func_graph_value_nodes
();
for
(
auto
&
fg_value_node
:
fg_value_node
s
)
{
auto
&
fg
s
=
fg
->
func_graphs_used
();
for
(
auto
&
item
:
fg
s
)
{
fg
->
seen_
=
seen_num
;
auto
gt
=
GetValueNode
<
FuncGraphPtr
>
(
fg_value_node
.
first
)
;
auto
gt
=
item
.
first
;
parents
->
update
(
SeekParents
(
gt
,
seen_num
));
}
(
void
)
parents
->
erase
(
fg
);
...
...
@@ -865,15 +860,15 @@ void FVTotalComputer::RealRecompute() {
}
}
auto
&
used
=
fg
->
func_graph
_value_nodes
();
auto
&
used
=
fg
->
func_graph
s_used
();
for
(
auto
&
iter
:
used
)
{
auto
p
=
manager
->
parent
(
GetValueNode
<
FuncGraphPtr
>
(
iter
.
first
)
);
auto
p
=
manager
->
parent
(
iter
.
first
);
if
(
p
==
nullptr
)
{
continue
;
}
auto
curr
=
fg
;
while
(
curr
!=
p
)
{
(
void
)
CounterFuncGraphCollector
::
Mod
(
curr
,
GetValueNode
<
FuncGraphPtr
>
(
iter
.
first
)
,
iter
.
second
);
(
void
)
CounterFuncGraphCollector
::
Mod
(
curr
,
iter
.
first
,
iter
.
second
);
curr
=
manager
->
parent
(
curr
);
}
}
...
...
@@ -899,8 +894,8 @@ void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) {
while
(
!
todo
.
empty
())
{
todo_new
.
clear
();
for
(
auto
&
gt
:
todo
)
{
for
(
auto
&
item
:
gt
->
func_graph
_value_nodes
())
{
auto
used_fg
=
GetValueNode
<
FuncGraphPtr
>
(
item
.
first
)
;
for
(
auto
&
item
:
gt
->
func_graph
s_used
())
{
auto
used_fg
=
item
.
first
;
if
(
used_fg
==
fg
)
{
func_graph_used_total_analysis_
[
fg
].
add
(
used_fg
);
continue
;
...
...
@@ -925,8 +920,8 @@ bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &f
while
(
!
todo
.
empty
())
{
todo_new
.
clear
();
for
(
auto
&
gt
:
todo
)
{
for
(
auto
&
item
:
gt
->
func_graph
_value_nodes
())
{
auto
used_g
=
GetValueNode
<
FuncGraphPtr
>
(
item
.
first
)
;
for
(
auto
&
item
:
gt
->
func_graph
s_used
())
{
auto
used_g
=
item
.
first
;
if
(
used_g
==
fg
)
{
return
true
;
}
...
...
@@ -957,9 +952,9 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<F
}
}
else
{
trace
->
push_back
(
fg
);
auto
&
items
=
fg
->
func_graph
_value_nodes
();
auto
&
items
=
fg
->
func_graph
s_used
();
for
(
auto
iter
=
items
.
begin
();
iter
!=
items
.
end
();
(
void
)
iter
++
)
{
CheckRecursiveGraphs
(
GetValueNode
<
FuncGraphPtr
>
(
iter
->
first
)
,
trace
);
CheckRecursiveGraphs
(
iter
->
first
,
trace
);
}
trace
->
pop_back
();
if
(
!
recursive_map_
.
count
(
fg
))
{
...
...
@@ -973,14 +968,13 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, size_t seen_num) {
MS_LOG
(
DEBUG
)
<<
fg
->
ToString
()
<<
" had been checked"
;
return
false
;
}
auto
&
j_fg
_value_nodes
=
fg
->
j_func_graph_value_node
s
();
if
(
!
j_fg
_value_node
s
.
empty
())
{
auto
&
j_fg
s
=
fg
->
j_func_graph
s
();
if
(
!
j_fgs
.
empty
())
{
// check g1->J(fg)->g2->g cycle;
auto
contains_j
=
std
::
find_if
(
j_fg_value_nodes
.
begin
(),
j_fg_value_nodes
.
end
(),
[
seen_num
](
const
std
::
pair
<
AnfNodePtr
,
int
>
iter
)
{
return
GetValueNode
<
FuncGraphPtr
>
(
iter
.
first
)
->
seen_
!=
seen_num
;
});
if
(
contains_j
!=
j_fg_value_nodes
.
end
())
{
auto
contains_j
=
std
::
find_if
(
j_fgs
.
begin
(),
j_fgs
.
end
(),
[
seen_num
](
const
std
::
pair
<
FuncGraphPtr
,
int
>
iter
)
{
return
iter
.
first
->
seen_
!=
seen_num
;
});
if
(
contains_j
!=
j_fgs
.
end
())
{
MS_LOG
(
DEBUG
)
<<
fg
->
ToString
()
<<
" contains J("
<<
contains_j
->
first
->
ToString
()
<<
")"
;
return
true
;
}
...
...
@@ -988,8 +982,8 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, size_t seen_num) {
fg
->
seen_
=
seen_num
;
// check if func graphs used contains J(func_graph);
for
(
auto
&
item
:
fg
->
func_graph
_value_nodes
())
{
auto
used_g
=
GetValueNode
<
FuncGraphPtr
>
(
item
.
first
)
;
for
(
auto
&
item
:
fg
->
func_graph
s_used
())
{
auto
used_g
=
item
.
first
;
if
(
SeekJ
(
used_g
,
seen_num
))
{
MS_LOG
(
DEBUG
)
<<
fg
->
ToString
()
<<
" users func graph "
<<
used_g
->
ToString
()
<<
" which contains J(func_graph)"
;
return
true
;
...
...
mindspore/ccsrc/parallel/step_parallel.cc
浏览文件 @
dbb86cb1
...
...
@@ -2187,7 +2187,7 @@ void MarkForwardCNode(const FuncGraphPtr &root) {
SetForwardFlag
(
all_nodes
);
}
else
{
for
(
auto
&
func_graph
:
graph_set
)
{
MS_LOG
(
INFO
)
<<
"The sub graph size of root is "
<<
root
->
func_graph
_value_nodes
().
size
();
MS_LOG
(
INFO
)
<<
"The sub graph size of root is "
<<
root
->
func_graph
s_used
().
size
();
auto
return_node
=
func_graph
->
get_return
();
MS_EXCEPTION_IF_NULL
(
return_node
);
auto
all_dfs_nodes
=
DeepLinkedGraphSearch
(
return_node
);
...
...
tests/ut/cpp/ir/manager_test.cc
浏览文件 @
dbb86cb1
...
...
@@ -462,8 +462,8 @@ TEST_F(TestManager, test_nested_manual) {
ASSERT_EQ
(
1
,
iter
.
second
.
size
());
}
ASSERT_EQ
(
1
,
f
->
func_graph
_value_nodes
().
size
());
ASSERT_EQ
(
0
,
g
->
func_graph
_value_nodes
().
size
());
ASSERT_EQ
(
1
,
f
->
func_graph
s_used
().
size
());
ASSERT_EQ
(
0
,
g
->
func_graph
s_used
().
size
());
ASSERT_EQ
(
0
,
f
->
free_variables
().
size
());
ASSERT_EQ
(
1
,
g
->
free_variables
().
size
());
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录