Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
5aae0d91
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5aae0d91
编写于
5月 26, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
5月 26, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1459 Insert assign nodes for linking sub graph
Merge pull request !1459 from zhoufeng/link-assign
上级
7f80d028
f868a285
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
542 addition
and
264 deletion
+542
-264
mindspore/ccsrc/device/ascend/ascend_label_assign.cc
mindspore/ccsrc/device/ascend/ascend_label_assign.cc
+17
-2
mindspore/ccsrc/device/kernel_runtime.h
mindspore/ccsrc/device/kernel_runtime.h
+1
-1
mindspore/ccsrc/session/ascend_control_parser.cc
mindspore/ccsrc/session/ascend_control_parser.cc
+281
-87
mindspore/ccsrc/session/ascend_control_parser.h
mindspore/ccsrc/session/ascend_control_parser.h
+24
-10
mindspore/ccsrc/session/ascend_session.cc
mindspore/ccsrc/session/ascend_session.cc
+93
-39
mindspore/ccsrc/session/ascend_session.h
mindspore/ccsrc/session/ascend_session.h
+4
-4
mindspore/ccsrc/session/kernel_graph.cc
mindspore/ccsrc/session/kernel_graph.cc
+80
-72
mindspore/ccsrc/session/kernel_graph.h
mindspore/ccsrc/session/kernel_graph.h
+7
-4
mindspore/ccsrc/session/session_basic.cc
mindspore/ccsrc/session/session_basic.cc
+34
-43
mindspore/ccsrc/vm/transform.cc
mindspore/ccsrc/vm/transform.cc
+1
-2
未找到文件。
mindspore/ccsrc/device/ascend/ascend_label_assign.cc
浏览文件 @
5aae0d91
...
...
@@ -28,6 +28,9 @@ namespace device {
namespace
ascend
{
static
void
UpdateLabelGoto
(
NotNull
<
CNodePtr
>
node
)
{
if
(
AnfAlgo
::
HasNodeAttr
(
kAttrLabelIndex
,
node
))
{
return
;
}
if
(
node
->
size
()
<=
kLabelGotoLabelId
)
{
MS_LOG
(
EXCEPTION
)
<<
"Node "
<<
node
->
DebugString
()
<<
" has invalid input size "
<<
node
->
size
();
}
...
...
@@ -42,6 +45,9 @@ static void UpdateLabelGoto(NotNull<CNodePtr> node) {
}
static
void
UpdateLabelSwitch
(
NotNull
<
CNodePtr
>
node
)
{
if
(
AnfAlgo
::
HasNodeAttr
(
kAttrLabelIndex
,
node
))
{
return
;
}
if
(
node
->
size
()
<=
kLabelGotoLabelId
)
{
MS_LOG
(
EXCEPTION
)
<<
"Node "
<<
node
->
DebugString
()
<<
" has invalid input size "
<<
node
->
size
();
}
...
...
@@ -69,9 +75,12 @@ static void AssignLabelForLabelSet(NotNull<std::shared_ptr<session::KernelGraph>
if
(
memo
->
find
(
graph
.
get
())
!=
memo
->
end
())
{
return
;
}
memo
->
insert
(
graph
.
get
());
MS_LOG
(
INFO
)
<<
"Assign label for "
<<
graph
->
ToString
();
auto
nodes
=
TopoSort
(
graph
->
get_return
());
graph
->
SetExecOrderByDefault
();
auto
nodes
=
graph
->
execution_order
();
for
(
auto
&
node
:
nodes
)
{
if
(
!
node
->
isa
<
CNode
>
())
{
continue
;
...
...
@@ -97,9 +106,15 @@ static void AssignLabelForGotoSwitch(NotNull<std::shared_ptr<session::KernelGrap
if
(
memo
->
find
(
graph
.
get
())
!=
memo
->
end
())
{
return
;
}
memo
->
insert
(
graph
.
get
());
MS_LOG
(
INFO
)
<<
"Process label goto/switch for "
<<
graph
->
ToString
();
auto
nodes
=
TopoSort
(
graph
->
get_return
());
graph
->
SetExecOrderByDefault
();
auto
nodes
=
graph
->
execution_order
();
auto
end_goto
=
graph
->
get_end_goto
();
if
(
end_goto
!=
nullptr
)
{
nodes
.
push_back
(
end_goto
);
}
for
(
auto
&
node
:
nodes
)
{
if
(
!
node
->
isa
<
CNode
>
())
{
continue
;
...
...
mindspore/ccsrc/device/kernel_runtime.h
浏览文件 @
5aae0d91
...
...
@@ -53,6 +53,7 @@ class KernelRuntime {
virtual
bool
GenTask
(
const
session
::
KernelGraph
*
graph
);
bool
LaunchKernel
(
const
session
::
KernelGraph
*
graph
);
virtual
void
AssignStaticMemoryInput
(
const
session
::
KernelGraph
*
graph
);
virtual
void
AssignStaticMemoryValueNode
(
session
::
KernelGraph
*
graph
);
#ifdef ENABLE_DUMP_E2E
DumpConfPtr
GetDumpConf
();
...
...
@@ -67,7 +68,6 @@ class KernelRuntime {
TypeId
type_id
)
=
0
;
virtual
bool
SyncStream
()
=
0
;
void
AssignStaticMemory
(
session
::
KernelGraph
*
graph
);
void
AssignStaticMemoryValueNode
(
session
::
KernelGraph
*
graph
);
void
AssignDynamicMemory
(
session
::
KernelGraph
*
graph
);
void
ReuseAssignDynamicMemory
(
session
::
KernelGraph
*
graph
);
void
AssignNodeOutputMem
(
int
flag
,
const
AnfNodePtr
&
node
,
int
index
);
...
...
mindspore/ccsrc/session/ascend_control_parser.cc
浏览文件 @
5aae0d91
...
...
@@ -22,49 +22,78 @@
namespace
mindspore
{
namespace
session
{
static
VectorRef
GetCallArgs
(
std
::
vector
<
AnfNodePtr
>::
iterator
iter_begin
,
std
::
vector
<
AnfNodePtr
>::
iterator
iter_end
)
{
VectorRef
call_args
;
for
(
auto
iter
=
iter_begin
;
iter
!=
iter_end
;
++
iter
)
{
if
(
utils
::
isa
<
ValueNode
>
(
*
iter
))
{
call_args
.
push_back
(
GetValueNode
(
*
iter
));
}
else
{
call_args
.
push_back
(
*
iter
);
void
AscendControlParser
::
ChildGraphDataAssign
(
const
std
::
map
<
uint32_t
,
KernelGraphPtr
>
&
graph_id_map
)
{
for
(
auto
&
iter
:
graph_id_map
)
{
auto
&
kg
=
iter
.
second
;
MS_EXCEPTION_IF_NULL
(
kg
);
auto
real_inputs
=
kg
->
real_inputs
();
for
(
auto
&
it
:
real_inputs
)
{
auto
&
parameter
=
it
.
first
;
auto
&
args
=
it
.
second
;
for
(
auto
&
arg
:
args
)
{
MS_EXCEPTION_IF_NULL
(
arg
);
if
(
arg
->
isa
<
Parameter
>
())
{
MS_LOG
(
INFO
)
<<
"Parameter should be reused, no need insert assign, parameter: "
<<
parameter
->
DebugString
()
<<
", arg:"
<<
arg
->
DebugString
();
continue
;
}
auto
target_graph_iter
=
graph_id_map
.
find
(
AnfAlgo
::
GetGraphId
(
arg
.
get
()));
if
(
target_graph_iter
==
graph_id_map
.
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"Graph id "
<<
AnfAlgo
::
GetGraphId
(
arg
.
get
())
<<
" not found."
;
}
InsertAssignToGraph
(
NOT_NULL
(
target_graph_iter
->
second
),
NOT_NULL
(
arg
),
NOT_NULL
(
parameter
));
}
}
}
return
call_args
;
}
void
AscendControlParser
::
LinkGraph
(
NotNull
<
KernelGraphPtr
>
kg
)
{
std
::
set
<
KernelGraphPtr
>
memo
;
ProcessKernelGraph
(
kg
,
nullptr
,
nullptr
,
{},
NOT_NULL
(
&
memo
));
ProcessKernelGraph
(
kg
,
nullptr
,
nullptr
,
NOT_NULL
(
&
memo
));
std
::
map
<
uint32_t
,
KernelGraphPtr
>
graph_id_map
;
for
(
auto
&
g
:
memo
)
{
if
(
graph_id_map
.
find
(
g
->
graph_id
())
!=
graph_id_map
.
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"Two graph has same graph id "
<<
g
->
graph_id
()
<<
", graph: "
<<
graph_id_map
[
g
->
graph_id
()]
->
ToString
()
<<
" "
<<
g
->
ToString
();
}
graph_id_map
[
g
->
graph_id
()]
=
g
;
}
ChildGraphDataAssign
(
graph_id_map
);
}
// Scans `list` from index `start` and returns the first node that is not a
// Partial primitive CNode and for which AnfAlgo::IsRealKernel() is true.
// The final element of `list` is never considered (presumably it is the
// graph's return node — TODO confirm against callers).
// Returns nullptr when no such node exists, including when `list` is empty or
// `start` is at/after the last index.
CNodePtr AscendControlParser::GetNextRealKernel(std::vector<CNodePtr> list, size_t start) {
  // `list.size() - 1` is unsigned: on an empty list it would wrap around to
  // SIZE_MAX and the loop below would index out of bounds, so bail out first.
  if (list.empty()) {
    return nullptr;
  }
  const size_t last_index = list.size() - 1;
  for (size_t i = start; i < last_index; ++i) {
    if (!IsPrimitiveCNode(list[i], prim::kPrimPartial) && AnfAlgo::IsRealKernel(list[i])) {
      return list[i];
    }
  }
  return nullptr;
}
NotNull
<
CNodePtr
>
AscendControlParser
::
ProcessKernelGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
const
CNodePtr
&
last_node
,
const
CNodePtr
&
last_label
,
const
VectorRef
&
args
,
const
CNodePtr
&
last_label
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
)
{
MS_LOG
(
INFO
)
<<
"Start process KernelGraph "
<<
kg
->
ToString
();
// 0. recursive condition
// 1. recursive condition
if
(
memo
->
find
(
kg
)
!=
memo
->
end
())
{
MS_LOG
(
INFO
)
<<
"KernelGraph has beed processed: "
<<
kg
->
ToString
();
return
NOT_NULL
(
kg
->
get_start_label
());
}
memo
->
insert
(
kg
.
get
());
// 2. args replace placeholder
LinkParentGraph
(
kg
,
last_node
,
last_label
,
args
);
LinkParentGraph
(
kg
,
last_node
,
last_label
,
memo
);
// 3. topological sort
std
::
vector
<
CNodePtr
>
nodes
=
GetCNodes
(
TopoSort
(
kg
->
get_return
()));
kg
->
SetExecOrderByDefault
();
std
::
vector
<
CNodePtr
>
nodes
=
kg
->
execution_order
();
if
(
nodes
.
empty
())
{
MS_LOG
(
EXCEPTION
)
<<
"KernelGraph "
<<
kg
->
ToString
()
<<
" has no cnodes!"
;
}
// 4. insert first_label
auto
start_label
=
kg
->
NewCNode
({
std
::
make_shared
<
ValueNode
>
(
std
::
make_shared
<
Primitive
>
(
kLabelSetOpName
))});
for
(
auto
node
:
nodes
)
{
if
(
!
IsPrimitiveCNode
(
node
,
prim
::
kPrimPartial
))
{
InsertControlDependToGraph
(
kg
,
NOT_NULL
(
start_label
),
NOT_NULL
(
node
));
break
;
}
}
MS_LOG
(
INFO
)
<<
"Insert start label "
<<
start_label
->
DebugString
()
<<
" to "
<<
kg
->
ToString
();
kg
->
set_start_label
(
start_label
);
// 5. traverse
for
(
size_t
i
=
0
;
i
<
nodes
.
size
();
++
i
)
{
...
...
@@ -79,17 +108,19 @@ NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr
}
AnfNodePtr
arg
=
cnode
->
input
(
kCNodeCallArg
);
if
(
IsValueNode
<
KernelGraph
>
(
arg
))
{
RecurseCall
(
kg
,
NOT_NULL
(
cnode
),
(
i
+
1
<
nodes
.
size
()
?
nodes
[
i
+
1
]
:
nullptr
),
memo
);
RecurseCall
(
kg
,
NOT_NULL
(
cnode
),
GetNextRealKernel
(
nodes
,
i
+
1
),
memo
);
}
else
if
(
!
arg
->
isa
<
CNode
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"Unknown type call node "
<<
cnode
->
DebugString
();
}
else
if
(
IsPrimitiveCNode
(
arg
->
cast
<
CNodePtr
>
(),
prim
::
kPrimSwitch
))
{
auto
arg_cnode
=
arg
->
cast
<
CNodePtr
>
();
cnode
->
set_inputs
(
cnode
->
inputs
());
RecurseSwitch
(
kg
,
NOT_NULL
(
cnode
),
memo
);
MS_EXCEPTION_IF_NULL
(
arg_cnode
);
cnode
->
set_inputs
(
arg_cnode
->
inputs
());
RecurseSwitch
(
kg
,
NOT_NULL
(
cnode
),
GetNextRealKernel
(
nodes
,
i
+
1
),
memo
);
}
else
if
(
IsPrimitiveCNode
(
arg
->
cast
<
CNodePtr
>
(),
prim
::
kPrimSwitchLayer
))
{
auto
arg_cnode
=
arg
->
cast
<
CNodePtr
>
();
cnode
->
set_inputs
(
cnode
->
inputs
());
RecurseSwitchLayer
(
kg
,
NOT_NULL
(
cnode
),
memo
);
MS_EXCEPTION_IF_NULL
(
arg_cnode
);
cnode
->
set_inputs
(
arg_cnode
->
inputs
());
RecurseSwitchLayer
(
kg
,
NOT_NULL
(
cnode
),
GetNextRealKernel
(
nodes
,
i
+
1
),
memo
);
}
}
...
...
@@ -97,16 +128,6 @@ NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr
return
NOT_NULL
(
start_label
);
}
std
::
vector
<
CNodePtr
>
AscendControlParser
::
GetCNodes
(
const
std
::
vector
<
AnfNodePtr
>
&
in
)
{
std
::
vector
<
CNodePtr
>
out
;
for
(
auto
&
node
:
in
)
{
if
(
node
->
isa
<
CNode
>
())
{
out
.
push_back
(
node
->
cast
<
CNodePtr
>
());
}
}
return
out
;
}
void
AscendControlParser
::
InsertDependToGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
attch_node
)
{
std
::
vector
<
AnfNodePtr
>
inputs
=
{
NewValueNode
(
std
::
make_shared
<
Primitive
>
(
"depend"
))};
auto
return_node
=
kg
->
get_return
();
...
...
@@ -128,11 +149,7 @@ void AscendControlParser::InsertControlDependToGraph(NotNull<KernelGraphPtr> kg,
}
void
AscendControlParser
::
LinkParentGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
const
CNodePtr
&
from_graph_call_node
,
const
CNodePtr
&
last_label
,
const
VectorRef
&
args
)
{
if
(
from_graph_call_node
!=
nullptr
)
{
SetSubGraphInput
(
kg
,
NOT_NULL
(
from_graph_call_node
),
args
);
}
const
CNodePtr
&
last_label
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
)
{
auto
origin_return
=
kg
->
get_return
();
std
::
vector
<
AnfNodePtr
>
origin_return_inputs
=
origin_return
->
inputs
();
// if entry graph, replace return with make_tuple
...
...
@@ -146,7 +163,8 @@ void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNod
// else replace return with label_goto
auto
label_goto
=
kg
->
NewCNode
({
std
::
make_shared
<
ValueNode
>
(
std
::
make_shared
<
Primitive
>
(
kLabelGotoOpName
)),
last_label
});
InsertDependToGraph
(
kg
,
NOT_NULL
(
label_goto
));
MS_LOG
(
INFO
)
<<
"Insert end goto "
<<
label_goto
->
DebugString
()
<<
" to "
<<
kg
->
ToString
();
kg
->
set_end_goto
(
label_goto
);
}
}
...
...
@@ -157,13 +175,14 @@ void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodeP
// 1 get kernel graph
auto
origin_inputs
=
cur_node
->
inputs
();
std
::
vector
<
AnfNodePtr
>
new_inputs
=
{
std
::
make_shared
<
ValueNode
>
(
std
::
make_shared
<
Primitive
>
(
kLabelGotoOpName
))};
auto
call_args
=
GetCallArgs
(
origin_inputs
.
begin
()
+
1
,
origin_inputs
.
end
());
if
(
!
IsValueNode
<
KernelGraph
>
(
origin_inputs
[
kCNodeCallArg
]))
{
MS_LOG
(
WARNING
)
<<
"Node "
<<
cur_node
->
DebugString
(
10
)
<<
" index "
<<
kCNodeCallArg
<<
" is not a ValueNode"
;
return
;
}
// 2 return label
auto
back_label
=
kg
->
NewCNode
({
std
::
make_shared
<
ValueNode
>
(
std
::
make_shared
<
Primitive
>
(
kLabelSetOpName
))});
MS_LOG
(
INFO
)
<<
"Insert back label "
<<
back_label
->
DebugString
()
<<
" to "
<<
kg
->
ToString
()
<<
" call node "
<<
cur_node
->
DebugString
();
// 3 add depend relationship
InsertControlDependToGraph
(
kg
,
cur_node
,
NOT_NULL
(
back_label
));
if
(
next_node
!=
nullptr
&&
next_node
!=
kg
->
get_return
())
{
...
...
@@ -173,7 +192,7 @@ void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodeP
// 4 modify call op to goto op
cur_node
->
set_input
(
kCNodePrim
,
new_inputs
[
kCNodePrim
]);
// 5 recurse sub graph
CNodePtr
sub_label
=
ProcessKernelGraph
(
NOT_NULL
(
call_kg
),
cur_node
,
back_label
,
call_args
,
memo
);
CNodePtr
sub_label
=
ProcessKernelGraph
(
NOT_NULL
(
call_kg
),
cur_node
,
back_label
,
memo
);
new_inputs
.
push_back
(
sub_label
);
new_inputs
.
insert
(
new_inputs
.
end
(),
origin_inputs
.
begin
(),
origin_inputs
.
end
());
cur_node
->
set_inputs
(
new_inputs
);
...
...
@@ -182,32 +201,37 @@ void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodeP
}
void
AscendControlParser
::
RecurseSwitch
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
CNodePtr
>
cur_node
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
)
{
const
CNodePtr
&
next_node
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
)
{
MS_LOG
(
INFO
)
<<
"process switch node "
<<
cur_node
->
DebugString
();
if
(
cur_node
->
size
()
<
kCNodeSwitchLength
)
{
MS_LOG
(
EXCEPTION
)
<<
"Inputs of apply node must more than "
<<
kCNodeSwitchLength
;
}
// 1 return label
auto
back_label
=
kg
->
NewCNode
({
std
::
make_shared
<
ValueNode
>
(
prim
::
kPrimLabelSet
)});
// 2 recurse sub graph
auto
back_label
=
kg
->
NewCNode
({
std
::
make_shared
<
ValueNode
>
(
std
::
make_shared
<
Primitive
>
(
kLabelSetOpName
))});
MS_LOG
(
INFO
)
<<
"Insert back label "
<<
back_label
->
DebugString
()
<<
" to "
<<
kg
->
ToString
()
<<
" switch node "
<<
cur_node
->
DebugString
();
// 2 add depend relationship
InsertControlDependToGraph
(
kg
,
cur_node
,
NOT_NULL
(
back_label
));
if
(
next_node
!=
nullptr
&&
next_node
!=
kg
->
get_return
())
{
InsertControlDependToGraph
(
kg
,
NOT_NULL
(
back_label
),
NOT_NULL
(
next_node
));
}
// 3 recurse sub graph
auto
origin_switch_inputs
=
cur_node
->
inputs
();
std
::
vector
<
AnfNodePtr
>
new_switch_inputs
=
{
std
::
make_shared
<
ValueNode
>
(
std
::
make_shared
<
Primitive
>
(
kLabelSwitchOpName
)),
origin_switch_inputs
[
kCNodeSwitchCond
]};
for
(
size_t
i
=
kCNodeSwitchCond
+
1
;
i
<
kCNodeSwitchLength
;
++
i
)
{
//
2
.1 branch kernel graph and args
//
3
.1 branch kernel graph and args
CNodePtr
partial
;
KernelGraphPtr
branch_fg
;
VectorRef
call_args
;
std
::
tie
(
partial
,
branch_fg
,
call_args
)
=
ParsePartial
(
NOT_NULL
(
origin_switch_inputs
[
i
]));
// 2.2 add depend relationship
InsertControlDependToGraph
(
kg
,
cur_node
,
NOT_NULL
(
back_label
));
// 2.3 recurse sub graph
CNodePtr
branch_label
=
ProcessKernelGraph
(
NOT_NULL
(
branch_fg
),
cur_node
,
back_label
,
call_args
,
memo
);
std
::
tie
(
partial
,
branch_fg
)
=
ParsePartial
(
NOT_NULL
(
origin_switch_inputs
[
i
]));
// 3.2 recurse sub graph
CNodePtr
branch_label
=
ProcessKernelGraph
(
NOT_NULL
(
branch_fg
),
cur_node
,
back_label
,
memo
);
new_switch_inputs
.
push_back
(
branch_label
);
}
std
::
swap
(
new_switch_inputs
[
kCNodeSwitchTrue
],
new_switch_inputs
[
kCNodeSwitchFalse
]);
new_switch_inputs
.
insert
(
new_switch_inputs
.
end
(),
origin_switch_inputs
.
begin
(),
origin_switch_inputs
.
end
());
cur_node
->
set_inputs
(
new_switch_inputs
);
cur_node
->
set_abstract
(
nullptr
);
...
...
@@ -215,7 +239,7 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod
}
void
AscendControlParser
::
RecurseSwitchLayer
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
CNodePtr
>
cur_node
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
)
{
const
CNodePtr
&
next_node
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
)
{
MS_LOG
(
INFO
)
<<
"process switch node "
<<
cur_node
->
DebugString
();
if
(
cur_node
->
size
()
<
kCNodeSwitchLayerLength
)
{
...
...
@@ -229,21 +253,24 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull
}
auto
branch_partial
=
utils
::
cast
<
CNodePtr
>
(
branch_tuple
)
->
inputs
();
// 1 return label
auto
back_label
=
kg
->
NewCNode
({
std
::
make_shared
<
ValueNode
>
(
std
::
make_shared
<
Primitive
>
(
kLabelSwitchOpName
))});
// 2 recurse sub graph
auto
back_label
=
kg
->
NewCNode
({
std
::
make_shared
<
ValueNode
>
(
std
::
make_shared
<
Primitive
>
(
kLabelSetOpName
))});
// 2 add depend relationship
InsertControlDependToGraph
(
kg
,
cur_node
,
NOT_NULL
(
back_label
));
if
(
next_node
!=
nullptr
&&
next_node
!=
kg
->
get_return
())
{
InsertControlDependToGraph
(
kg
,
NOT_NULL
(
back_label
),
NOT_NULL
(
next_node
));
}
// 3 recurse sub graph
auto
origin_switch_inputs
=
cur_node
->
inputs
();
std
::
vector
<
AnfNodePtr
>
new_switch_inputs
=
{
std
::
make_shared
<
ValueNode
>
(
prim
::
kPrimLabelSwitch
),
origin_switch_inputs
[
kCNodeSwitchCond
]};
std
::
vector
<
AnfNodePtr
>
new_switch_inputs
=
{
std
::
make_shared
<
ValueNode
>
(
std
::
make_shared
<
Primitive
>
(
kLabelSwitchOpName
)),
origin_switch_inputs
[
kCNodeSwitchCond
]};
for
(
size_t
i
=
0
;
i
<
branch_partial
.
size
();
++
i
)
{
//
2
.1 branch kernel graph and args
//
3
.1 branch kernel graph and args
CNodePtr
partial
;
KernelGraphPtr
branch_fg
;
VectorRef
call_args
;
std
::
tie
(
partial
,
branch_fg
,
call_args
)
=
ParsePartial
(
NOT_NULL
(
origin_switch_inputs
[
i
]));
// 2.2 add depend relationship
InsertControlDependToGraph
(
kg
,
cur_node
,
NOT_NULL
(
back_label
));
// 2.3 recurse sub graph
CNodePtr
branch_label
=
ProcessKernelGraph
(
NOT_NULL
(
branch_fg
),
cur_node
,
back_label
,
call_args
,
memo
);
std
::
tie
(
partial
,
branch_fg
)
=
ParsePartial
(
NOT_NULL
(
origin_switch_inputs
[
i
]));
// 3.2 recurse sub graph
CNodePtr
branch_label
=
ProcessKernelGraph
(
NOT_NULL
(
branch_fg
),
cur_node
,
back_label
,
memo
);
new_switch_inputs
.
push_back
(
branch_label
);
}
new_switch_inputs
.
insert
(
new_switch_inputs
.
end
(),
branch_partial
.
begin
(),
branch_partial
.
end
());
...
...
@@ -252,7 +279,7 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull
MS_LOG
(
INFO
)
<<
"success process switch layer "
<<
cur_node
->
DebugString
();
}
std
::
tuple
<
CNodePtr
,
KernelGraphPtr
,
VectorRef
>
AscendControlParser
::
ParsePartial
(
NotNull
<
AnfNodePtr
>
node
)
{
std
::
tuple
<
CNodePtr
,
KernelGraphPtr
>
AscendControlParser
::
ParsePartial
(
NotNull
<
AnfNodePtr
>
node
)
{
if
(
!
node
.
get
()
->
isa
<
CNode
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"Switch branches must be partial, node: "
<<
node
->
DebugString
();
}
...
...
@@ -263,9 +290,8 @@ std::tuple<CNodePtr, KernelGraphPtr, VectorRef> AscendControlParser::ParsePartia
}
auto
partial_inputs
=
partial_cnode
->
inputs
();
auto
branch_kg
=
GetValueNode
<
KernelGraphPtr
>
(
partial_inputs
[
kCNodePartialFunc
]);
auto
call_args
=
GetCallArgs
(
partial_inputs
.
begin
()
+
kCNodePartialFunc
+
1
,
partial_inputs
.
end
());
return
{
partial_cnode
,
branch_kg
,
call_args
};
return
{
partial_cnode
,
branch_kg
};
}
void
AscendControlParser
::
InsertAssignToGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
from
,
...
...
@@ -289,31 +315,199 @@ void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNul
InsertDependToGraph
(
kg
,
NOT_NULL
(
assign_node
));
}
size_t
AscendControlParser
::
SetChildGraphInput
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
node
,
size_t
input_index
)
{
auto
output_num
=
AnfAlgo
::
GetOutputTensorNum
(
node
);
if
(
output_num
>
1
&&
!
AnfAlgo
::
CheckPrimitiveType
(
node
,
prim
::
kPrimTupleGetItem
))
{
return
input_index
+
output_num
;
// Looks through the set returned by to_graph->GetRealInput(param) and returns
// the first entry (in set iteration order) whose func_graph() equals
// `from_graph`. When no entry matches, an EXCEPTION-level log naming both
// graphs and the parameter is emitted (presumably this throws — the function
// has no return statement after it).
NotNull<AnfNodePtr> AscendControlParser::GetRealInput(NotNull<KernelGraphPtr> from_graph,
                                                      NotNull<KernelGraphPtr> to_graph,
                                                      NotNull<AnfNodePtr> param) {
  std::set<AnfNodePtr> candidates = to_graph->GetRealInput(param);
  for (auto arg : candidates) {
    const bool from_expected_graph = (arg->func_graph() == from_graph.get());
    if (from_expected_graph) {
      return NOT_NULL(arg);
    }
  }
  MS_LOG(EXCEPTION) << to_graph->ToString() << " input " << param->DebugString() << " not from "
                    << from_graph->ToString();
}
auto
&
graph_inputs
=
kg
->
inputs
();
if
(
input_index
>=
graph_inputs
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"input_index "
<<
input_index
<<
" out of range size "
<<
graph_inputs
.
size
();
void
AscendControlParser
::
LinkArgsToParam
(
NotNull
<
KernelGraphPtr
>
to_graph
,
NotNull
<
KernelGraphPtr
>
target_graph
,
NotNull
<
AnfNodePtr
>
arg
,
NotNull
<
AnfNodePtr
>
param
)
{
if
(
IsPrimitiveCNode
(
arg
,
prim
::
kPrimMakeTuple
)
&&
IsPrimitiveCNode
(
param
,
prim
::
kPrimMakeTuple
))
{
MS_LOG
(
INFO
)
<<
"Arg "
<<
arg
->
DebugString
()
<<
" Param "
<<
param
->
DebugString
()
<<
" is a tuple"
;
CNodePtr
cnode_arg
=
arg
.
get
()
->
cast
<
CNodePtr
>
();
CNodePtr
cnode_param
=
param
.
get
()
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode_arg
);
MS_EXCEPTION_IF_NULL
(
cnode_param
);
if
(
cnode_arg
->
size
()
!=
cnode_param
->
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"Arg "
<<
arg
->
DebugString
()
<<
" size "
<<
cnode_arg
->
size
()
<<
" but Param "
<<
param
->
DebugString
()
<<
" size "
<<
cnode_param
->
size
();
}
for
(
size_t
i
=
1
;
i
<
cnode_param
->
size
();
++
i
)
{
LinkArgsToParam
(
to_graph
,
target_graph
,
NOT_NULL
(
cnode_arg
->
input
(
i
)),
NOT_NULL
(
cnode_param
->
input
(
i
)));
}
}
else
if
(
arg
->
isa
<
CNode
>
())
{
InsertAssignToGraph
(
target_graph
,
arg
,
param
);
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"Arg "
<<
arg
->
DebugString
()
<<
" Param "
<<
param
->
DebugString
()
<<
" unknown type."
;
}
auto
backend_parameter
=
graph_inputs
[
input_index
];
if
(
node
.
get
()
->
isa
<
Parameter
>
())
{
MS_EXCEPTION_IF_NULL
(
backend_parameter
);
MS_LOG
(
INFO
)
<<
"Reuse node ["
<<
node
->
DebugString
()
<<
"], old node["
<<
backend_parameter
->
DebugString
()
<<
"] will be replaced."
;
kg
->
ReplaceNode
(
backend_parameter
,
node
);
return
input_index
;
}
// Runs RecurseGraph() on `root_graph` with a fresh, empty memo set and no
// current/end label-goto nodes. The execution order that RecurseGraph()
// returns is deliberately discarded.
void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
  std::set<KernelGraphPtr> memo;
  static_cast<void>(RecurseGraph(nullptr, nullptr, root_graph, NOT_NULL(&memo)));
}
std
::
vector
<
CNodePtr
>
AscendControlParser
::
RecurseGraph
(
const
CNodePtr
&
cur_label_goto
,
const
CNodePtr
&
end_label_goto
,
NotNull
<
KernelGraphPtr
>
graph
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
)
{
MS_LOG
(
INFO
)
<<
"graph:"
<<
graph
->
graph_id
()
<<
" start"
;
auto
print_vector
=
[
&
](
std
::
vector
<
CNodePtr
>
vec
)
->
void
{
MS_LOG
(
INFO
)
<<
"graph:"
<<
graph
->
graph_id
()
<<
"execution order"
;
for
(
size_t
i
=
0
;
i
<
vec
.
size
();
i
++
)
{
MS_LOG
(
INFO
)
<<
"["
<<
i
<<
"]["
<<
vec
[
i
]
->
DebugString
()
<<
"]"
;
}
};
if
(
memo
->
find
(
graph
)
!=
memo
->
end
())
{
return
{};
}
memo
->
insert
(
graph
.
get
());
graph
->
SetExecOrderByDefault
();
std
::
vector
<
CNodePtr
>
cnodes
=
graph
->
execution_order
();
std
::
map
<
uint32_t
,
CNodePtr
>
label_map
;
std
::
map
<
CNodePtr
,
std
::
vector
<
uint32_t
>>
label_switch_map
;
std
::
tie
(
label_map
,
label_switch_map
)
=
GetLabelNode
(
cnodes
);
std
::
vector
<
CNodePtr
>
execution_order
;
for
(
auto
&
node
:
cnodes
)
{
execution_order
.
push_back
(
node
);
if
(
node
==
graph
->
get_end_goto
())
{
continue
;
}
auto
label_iter
=
std
::
find_if
(
label_map
.
begin
(),
label_map
.
end
(),
[
node
](
const
std
::
map
<
uint32_t
,
CNodePtr
>::
value_type
iter
)
{
return
iter
.
second
==
node
;
});
if
(
AnfAlgo
::
CheckPrimitiveType
(
node
,
prim
::
kPrimLabelGoto
))
{
if
(
!
CheckLabelIndex
(
label_iter
->
first
,
0
,
label_iter
->
second
,
graph
))
{
MS_LOG
(
EXCEPTION
)
<<
"Check label index fail"
;
}
auto
child_graph
=
graph
->
child_graph_order
()[
label_iter
->
first
];
if
(
child_graph
==
graph
->
parent_graph
())
{
continue
;
}
std
::
map
<
uint32_t
,
CNodePtr
>
child_label_map
;
std
::
tie
(
child_label_map
,
std
::
ignore
)
=
GetLabelNode
(
child_graph
->
execution_order
());
auto
child_execution_order
=
RecurseGraph
(
child_label_map
.
begin
()
->
second
,
child_label_map
.
rbegin
()
->
second
,
NOT_NULL
(
child_graph
),
memo
);
execution_order
.
insert
(
execution_order
.
end
(),
child_execution_order
.
begin
(),
child_execution_order
.
end
());
}
else
if
(
AnfAlgo
::
CheckPrimitiveType
(
node
,
prim
::
kPrimLabelSwitch
))
{
std
::
vector
<
uint32_t
>
label_list
=
label_switch_map
.
find
(
node
)
->
second
;
std
::
reverse
(
label_list
.
begin
(),
label_list
.
end
());
for
(
size_t
i
=
0
;
i
<
label_list
.
size
();
++
i
)
{
if
(
!
CheckLabelIndex
(
label_iter
->
first
+
i
,
label_list
[
i
],
label_iter
->
second
,
graph
))
{
MS_LOG
(
EXCEPTION
)
<<
"Check label index fail"
;
}
auto
child_graph
=
graph
->
child_graph_order
()[
label_iter
->
first
+
i
];
if
(
child_graph
==
graph
->
parent_graph
())
{
continue
;
}
std
::
map
<
uint32_t
,
CNodePtr
>
child_label_map
;
std
::
tie
(
child_label_map
,
std
::
ignore
)
=
GetLabelNode
(
child_graph
->
execution_order
());
auto
child_execution_order
=
RecurseGraph
(
child_label_map
.
begin
()
->
second
,
child_label_map
.
rbegin
()
->
second
,
NOT_NULL
(
child_graph
),
memo
);
execution_order
.
insert
(
execution_order
.
end
(),
child_execution_order
.
begin
(),
child_execution_order
.
end
());
}
}
}
InsertAssignToGraph
(
kg
,
node
,
NOT_NULL
(
backend_parameter
));
return
input_index
+
1
;
graph
->
set_execution_order
(
execution_order
);
print_vector
(
graph
->
execution_order
());
return
execution_order
;
}
// Currently has an empty body: none of `kg`, `from_graph_call_node` or `args`
// is read, and nothing is modified.
void AscendControlParser::SetSubGraphInput(NotNull<KernelGraphPtr> kg,
                                           NotNull<CNodePtr> from_graph_call_node,
                                           const VectorRef &args) {
}
// Verifies that a label index matches the start label of the child graph at
// position `order_index` in graph->child_graph_order().
//
// - `order_index` must be a valid position in the child graph order, otherwise
//   an EXCEPTION-level log is emitted.
// - When `cur_label` is a LabelGoto node, the index to compare is read from
//   its kAttrLabelIndex attribute and replaces the `label_index` argument.
// - The child graph's start label must carry kAttrLabelIndex.
//
// Returns true when the two indices are equal; otherwise logs a WARNING and
// returns false.
bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cur_label,
                                          NotNull<KernelGraphPtr> graph) {
  // The requested position has to exist in the child graph order.
  if (static_cast<size_t>(order_index) >= graph->child_graph_order().size()) {
    MS_LOG(EXCEPTION) << "Child graph order is wrong, graph " << graph->ToString() << " child graph size "
                      << graph->child_graph_order().size() << " goto index " << order_index;
  }

  // For a label_goto node the label index comes from the node's own attribute.
  if (AnfAlgo::CheckPrimitiveType(cur_label, prim::kPrimLabelGoto)) {
    if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, cur_label)) {
      MS_LOG(EXCEPTION) << "LabelSetKernel has no attr label_index";
    }
    auto primitive = AnfAlgo::GetCNodePrimitive(cur_label);
    MS_EXCEPTION_IF_NULL(primitive);
    label_index = GetValue<uint32_t>(primitive->GetAttr(kAttrLabelIndex));
  }

  // Fetch the label index stored on the child graph's start label.
  auto child_graph = graph->child_graph_order()[order_index];
  MS_EXCEPTION_IF_NULL(child_graph);
  auto entry_label = child_graph->get_start_label();
  if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, entry_label)) {
    MS_LOG(EXCEPTION) << "LabelSetKernel has no attr label_index";
  }
  auto start_primitive = AnfAlgo::GetCNodePrimitive(entry_label);
  MS_EXCEPTION_IF_NULL(start_primitive);
  const uint32_t expected_index = GetValue<uint32_t>(start_primitive->GetAttr(kAttrLabelIndex));

  if (label_index == expected_index) {
    return true;
  }
  MS_LOG(WARNING) << cur_label->DebugString() << " index " << label_index << " but " << entry_label->DebugString()
                  << " index " << expected_index << " current child graph order : " << order_index;
  return false;
}
// Builds two lookup tables from `nodes`, visited in order:
//   1. label_map: running slot number -> node. A LabelGoto node takes one
//      slot; a LabelSwitch node takes one slot per entry of its
//      kAttrLabelSwitchList attribute (every such slot maps to the same node).
//   2. label_switch_map: LabelSwitch node -> its kAttrLabelSwitchList value.
// A LabelSwitch node without kAttrLabelSwitchList triggers an EXCEPTION log.
// Nodes of any other primitive type are ignored.
std::tuple<std::map<uint32_t, CNodePtr>, std::map<CNodePtr, std::vector<uint32_t>>> AscendControlParser::GetLabelNode(
  const std::vector<CNodePtr> &nodes) {
  std::map<uint32_t, CNodePtr> label_map;
  std::map<CNodePtr, std::vector<uint32_t>> label_switch_map;
  uint32_t next_slot = 0;  // next free key in label_map
  for (auto &node : nodes) {
    if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) {
      label_map[next_slot] = node;
      next_slot += 1;
      continue;
    }
    if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) {
      continue;
    }
    if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, node)) {
      MS_LOG(EXCEPTION) << "LabelSwitchKernel has no attr label_switch_list";
    }
    auto primitive = AnfAlgo::GetCNodePrimitive(node);
    MS_EXCEPTION_IF_NULL(primitive);
    std::vector<uint32_t> branch_labels = GetValue<std::vector<uint32_t>>(primitive->GetAttr(kAttrLabelSwitchList));
    label_switch_map.insert({node, branch_labels});
    // Reserve one slot per switch branch, all pointing at the switch node.
    for (size_t remaining = branch_labels.size(); remaining > 0; --remaining) {
      label_map[next_slot] = node;
      next_slot += 1;
    }
  }
  return {label_map, label_switch_map};
}
// Recomputes the child graph order of `kg`:
//   1. resets kg's execution order via SetExecOrderByDefault();
//   2. finds every Call node in kg and collects, in encounter order, the
//      kernel graphs returned by AnfAlgo::GetCallNodeKernelGraph() for it;
//   3. sets kg as the parent graph of each collected graph, except when that
//      graph is kg's own parent graph;
//   4. logs the resulting order and stores it with set_child_graph_order().
void AscendControlParser::UpdateChildGraphOrder(NotNull<KernelGraphPtr> kg) {
  MS_LOG(INFO) << "graph id:" << kg->graph_id();
  kg->SetExecOrderByDefault();

  auto call_nodes = kg->FindNodeByPrimitive(std::make_shared<Primitive>(prim::kPrimCall->name()));
  std::vector<KernelGraphPtr> ordered_children;
  for (auto &call_node : call_nodes) {
    MS_EXCEPTION_IF_NULL(call_node);
    auto callee_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast<CNodePtr>());
    for (const auto &child_graph : callee_graphs) {
      MS_EXCEPTION_IF_NULL(child_graph);
      // A call back into kg's own parent must not be re-parented under kg.
      if (child_graph != kg->parent_graph()) {
        child_graph->set_parent_graph(kg.get());
      }
      ordered_children.push_back(child_graph);
    }
  }

  size_t position = 0;
  for (const auto &ordered_child : ordered_children) {
    MS_LOG(INFO) << "child graph[" << position << "][id:" << ordered_child->graph_id() << "]";
    ++position;
  }
  kg->set_child_graph_order(ordered_children);
}
}
// namespace session
}
// namespace mindspore
mindspore/ccsrc/session/ascend_control_parser.h
浏览文件 @
5aae0d91
...
...
@@ -17,6 +17,7 @@
#define MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H
#include <set>
#include <map>
#include <vector>
#include <tuple>
#include "session/kernel_graph.h"
...
...
@@ -28,31 +29,44 @@ namespace session {
class
AscendControlParser
{
public:
static
void
ChildGraphDataAssign
(
const
std
::
map
<
uint32_t
,
KernelGraphPtr
>
&
graph_id_map
);
static
void
LinkGraph
(
NotNull
<
KernelGraphPtr
>
kg
);
static
void
InsertDependToGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
attch_node
);
static
void
InsertControlDependToGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
first_node
,
NotNull
<
AnfNodePtr
>
second_node
);
static
void
ExecutorValidate
(
NotNull
<
KernelGraphPtr
>
root_graph
);
static
void
UpdateChildGraphOrder
(
NotNull
<
KernelGraphPtr
>
kg
);
private:
static
NotNull
<
CNodePtr
>
ProcessKernelGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
const
CNodePtr
&
last_node
,
const
CNodePtr
&
last_label
,
const
VectorRef
&
args
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
);
const
CNodePtr
&
last_label
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
);
static
void
RecurseCall
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
CNodePtr
>
cur_node
,
const
CNodePtr
&
next_node
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
);
static
void
RecurseSwitch
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
CNodePtr
>
cur_node
,
static
void
RecurseSwitch
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
CNodePtr
>
cur_node
,
const
CNodePtr
&
next_node
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
);
static
void
RecurseSwitchLayer
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
CNodePtr
>
cur_node
,
static
void
RecurseSwitchLayer
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
CNodePtr
>
cur_node
,
const
CNodePtr
&
next_node
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
);
static
std
::
vector
<
CNodePtr
>
GetCNodes
(
const
std
::
vector
<
AnfNodePtr
>
&
in
);
static
void
LinkParentGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
const
CNodePtr
&
from_graph_call_node
,
const
CNodePtr
&
last_label
,
const
VectorRef
&
args
);
static
void
SetSubGraphInput
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
CNodePtr
>
from_graph_call_node
,
const
VectorRef
&
args
);
static
std
::
tuple
<
CNodePtr
,
KernelGraphPtr
,
VectorRef
>
ParsePartial
(
NotNull
<
AnfNodePtr
>
node
);
const
CNodePtr
&
last_label
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
);
static
std
::
tuple
<
CNodePtr
,
KernelGraphPtr
>
ParsePartial
(
NotNull
<
AnfNodePtr
>
node
);
static
void
LinkArgsToParam
(
NotNull
<
KernelGraphPtr
>
to_graph
,
NotNull
<
KernelGraphPtr
>
target_graph
,
NotNull
<
AnfNodePtr
>
arg
,
NotNull
<
AnfNodePtr
>
param
);
static
NotNull
<
AnfNodePtr
>
GetRealInput
(
NotNull
<
KernelGraphPtr
>
from_graph
,
NotNull
<
KernelGraphPtr
>
to_graph
,
NotNull
<
AnfNodePtr
>
param
);
static
void
InsertAssignToGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
from
,
NotNull
<
AnfNodePtr
>
to
);
static
size_t
SetChildGraphInput
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
node
,
size_t
input_index
);
static
CNodePtr
GetNextRealKernel
(
std
::
vector
<
CNodePtr
>
list
,
size_t
start
);
// root graph order
static
std
::
tuple
<
std
::
map
<
uint32_t
,
CNodePtr
>
,
std
::
map
<
CNodePtr
,
std
::
vector
<
uint32_t
>>>
GetLabelNode
(
const
std
::
vector
<
CNodePtr
>
&
nodes
);
static
bool
CheckLabelIndex
(
uint32_t
order_index
,
uint32_t
label_index
,
const
CNodePtr
&
cnode
,
NotNull
<
KernelGraphPtr
>
graph
);
static
std
::
vector
<
CNodePtr
>
RecurseGraph
(
const
CNodePtr
&
cur_label_goto
,
const
CNodePtr
&
end_label_goto
,
NotNull
<
KernelGraphPtr
>
graph
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
);
// Input position 0 of a CNode; judging by the name, the slot that holds the node's primitive.
// NOTE(review): confirm against the code that indexes with this constant.
static
constexpr
size_t
kCNodePrim
=
0
;
// Input position 1 of a CNode; judging by the name, the first argument slot of a call node.
// NOTE(review): confirm against the code that indexes with this constant.
static
constexpr
size_t
kCNodeCallArg
=
1
;
...
...
mindspore/ccsrc/session/ascend_session.cc
浏览文件 @
5aae0d91
...
...
@@ -177,10 +177,6 @@ std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, co
for
(
size_t
i
=
0
;
i
<
cnodes
.
size
();
i
++
)
{
if
(
AnfAlgo
::
CheckPrimitiveType
(
cnodes
[
i
],
prim
::
kPrimCall
)
&&
!
AnfAlgo
::
IsSwitchCall
(
cnodes
[
i
]))
{
auto
call_kernel_graph
=
AnfAlgo
::
GetCallNodeKernelGraph
(
cnodes
[
i
]);
// if the graph is the true branch of a while, there is no need to split the graph
if
(
call_kernel_graph
.
size
()
==
1
&&
call_kernel_graph
[
0
]
==
cur_graph
.
parent_graph
())
{
continue
;
}
auto
prev_call_list
=
std
::
vector
<
CNodePtr
>
(
cnodes
.
begin
()
+
after_call_index
,
cnodes
.
begin
()
+
i
);
auto
call_list
=
std
::
vector
<
CNodePtr
>
(
1
,
cnodes
[
i
]);
after_call_index
=
i
+
1
;
...
...
@@ -195,10 +191,10 @@ std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, co
// if a call has kernel inputs, it is a child graph split from ME, so these kernel inputs should be set as the real inputs of
// the graph. For example, call input = (prim, graph, kernel1, kernel2), then real_input = [kernel1, kernel2]
void
UpdateRealInput
(
KernelGraph
*
graph
)
{
static
void
UpdateRealInput
(
KernelGraph
*
graph
)
{
auto
call_nodes
=
graph
->
FindNodeByPrimitive
(
prim
::
kPrimCall
);
auto
bind_call_
partial
_with_parameter
=
[
&
](
const
std
::
vector
<
AnfNodePtr
>
&
parameters
,
const
std
::
vector
<
AnfNodePtr
>
&
args
,
KernelGraph
*
child_graph
)
->
void
{
auto
bind_call_
arg
_with_parameter
=
[
&
](
const
std
::
vector
<
AnfNodePtr
>
&
parameters
,
const
std
::
vector
<
AnfNodePtr
>
&
args
,
KernelGraph
*
child_graph
)
->
void
{
MS_EXCEPTION_IF_NULL
(
child_graph
);
MS_LOG
(
INFO
)
<<
"start bind parameter of child graph:"
<<
child_graph
->
graph_id
();
if
(
args
.
empty
())
{
...
...
@@ -208,8 +204,21 @@ void UpdateRealInput(KernelGraph *graph) {
MS_LOG
(
EXCEPTION
)
<<
"graph:"
<<
child_graph
->
graph_id
()
<<
" parameters size:"
<<
parameters
.
size
()
<<
" and args size:"
<<
args
.
size
()
<<
" not equal!"
;
}
child_graph
->
SetExecOrderByDefault
();
for
(
size_t
i
=
0
;
i
<
parameters
.
size
();
i
++
)
{
MS_LOG
(
INFO
)
<<
"bind paramreter:"
<<
parameters
[
i
]
->
DebugString
()
<<
" ,arg:"
<<
args
[
i
]
->
DebugString
();
if
(
args
[
i
]
==
parameters
[
i
])
{
child_graph
->
SetRealInput
(
parameters
[
i
],
args
[
i
]);
MS_LOG
(
INFO
)
<<
"Parameter and arg are same"
;
continue
;
}
// if arg is a parameter ,then reuse this parameter
if
(
args
[
i
]
->
isa
<
Parameter
>
())
{
MS_LOG
(
INFO
)
<<
"Parameter:"
<<
parameters
[
i
]
->
DebugString
()
<<
" of graph:"
<<
child_graph
->
graph_id
()
<<
" reuse parameter:"
<<
args
[
i
]
->
DebugString
()
<<
" of graph:"
<<
AnfAlgo
::
GetGraphId
(
args
[
i
].
get
());
child_graph
->
ReplaceNode
(
parameters
[
i
],
args
[
i
]);
continue
;
}
child_graph
->
SetRealInput
(
parameters
[
i
],
args
[
i
]);
}
};
...
...
@@ -218,9 +227,10 @@ void UpdateRealInput(KernelGraph *graph) {
auto
child_graphs
=
AnfAlgo
::
GetCallNodeKernelGraph
(
call_node
);
if
(
child_graphs
.
size
()
==
1
)
{
MS_EXCEPTION_IF_NULL
(
child_graphs
[
0
]);
bind_call_partial_with_parameter
(
child_graphs
[
0
]
->
inputs
(),
std
::
vector
<
AnfNodePtr
>
(
call_node
->
inputs
().
begin
()
+
2
,
call_node
->
inputs
().
end
()),
child_graphs
[
0
].
get
());
std
::
vector
<
AnfNodePtr
>
real_args
=
std
::
vector
<
AnfNodePtr
>
(
call_node
->
inputs
().
begin
()
+
2
,
call_node
->
inputs
().
end
());
std
::
vector
<
AnfNodePtr
>
child_inputs
=
child_graphs
[
0
]
->
inputs
();
bind_call_arg_with_parameter
(
child_inputs
,
real_args
,
child_graphs
[
0
].
get
());
call_node
->
set_inputs
(
std
::
vector
<
AnfNodePtr
>
(
call_node
->
inputs
().
begin
(),
call_node
->
inputs
().
begin
()
+
2
));
}
else
if
(
child_graphs
.
size
()
==
2
)
{
auto
get_partial_args
=
[
&
](
size_t
input_index
)
->
std
::
vector
<
AnfNodePtr
>
{
...
...
@@ -237,8 +247,8 @@ void UpdateRealInput(KernelGraph *graph) {
std
::
vector
<
AnfNodePtr
>
(
partial_cnode
->
inputs
().
begin
(),
partial_cnode
->
inputs
().
begin
()
+
2
));
return
ret
;
};
bind_call_
partial
_with_parameter
(
child_graphs
[
0
]
->
inputs
(),
get_partial_args
(
2
),
child_graphs
[
0
].
get
());
bind_call_
partial
_with_parameter
(
child_graphs
[
1
]
->
inputs
(),
get_partial_args
(
3
),
child_graphs
[
1
].
get
());
bind_call_
arg
_with_parameter
(
child_graphs
[
0
]
->
inputs
(),
get_partial_args
(
2
),
child_graphs
[
0
].
get
());
bind_call_
arg
_with_parameter
(
child_graphs
[
1
]
->
inputs
(),
get_partial_args
(
3
),
child_graphs
[
1
].
get
());
}
}
}
...
...
@@ -248,6 +258,11 @@ void RecurseToUpdateCallRealInput(KernelGraph *graph) {
MS_LOG
(
INFO
)
<<
"start graph id:"
<<
graph
->
graph_id
();
graph
->
UpdateCallRealInput
();
for
(
auto
&
child_graph
:
graph
->
child_graph_order
())
{
if
(
child_graph
==
graph
->
parent_graph
())
{
MS_LOG
(
INFO
)
<<
"Child graph:"
<<
child_graph
->
graph_id
()
<<
",parent graph:"
<<
graph
->
parent_graph
()
->
graph_id
();
continue
;
}
RecurseToUpdateCallRealInput
(
child_graph
.
get
());
}
}
...
...
@@ -265,31 +280,31 @@ GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrL
GraphId
AscendSession
::
CompileGraph
(
NotNull
<
FuncGraphPtr
>
func_graph
)
{
MS_LOG
(
INFO
)
<<
"start"
;
auto
graph
=
ConstructKernelGraph
(
func_graph
);
MS_LOG
(
INFO
)
<<
"graph input size:"
<<
graph
->
inputs
().
size
();
// split switch
SplitGraphs
(
graph
);
MS_LOG
(
INFO
)
<<
"graph input size:"
<<
graph
->
inputs
().
size
();
// insert goto labels and label_sets
LinkChildGraphs
(
NOT_NULL
(
graph
));
MS_LOG
(
INFO
)
<<
"graph input size:"
<<
graph
->
inputs
().
size
();
// resource initialize
InitRuntimeResource
();
// assign label
AssignLabel
(
NOT_NULL
(
graph
));
if
(
!
graph
->
executable
())
{
return
graph
->
graph_id
();
}
for
(
auto
iter
:
graphs_
)
{
if
(
iter
.
second
==
graph
)
{
MS_LOG
(
INFO
)
<<
"Entry graph "
<<
graph
->
ToString
()
<<
" graph id "
<<
graph
->
graph_id
();
final_graph_id_
=
graph
->
graph_id
();
}
MS_LOG
(
INFO
)
<<
"CompileChildGraph "
<<
iter
.
second
->
ToString
();
CompileChildGraph
(
iter
.
second
);
}
MS_LOG
(
INFO
)
<<
"graph input size:"
<<
graph
->
inputs
().
size
();
// recurse compile child graph
RecurseCompileGraph
(
graph
);
MS_LOG
(
INFO
)
<<
"graph input size:"
<<
graph
->
inputs
().
size
();
// root graph validation, including generating the execution order and so on
RootGraphExecutorValidate
(
NOT_NULL
(
graph
));
MS_LOG
(
INFO
)
<<
"graph input size:"
<<
graph
->
inputs
().
size
();
// adjust kernel
AdjustKernel
(
graph
);
// root graph validation, including generating the execution order and so on
RootGraphExecutorValidate
(
graph
.
get
());
MS_LOG
(
INFO
)
<<
"graph input size:"
<<
graph
->
inputs
().
size
();
// assign stream
AssignStream
(
graph
);
// build kernel
BuildKernel
(
graph
);
// alloc mem
MemoryAlloc
(
graph
.
get
());
// task generate
...
...
@@ -365,6 +380,7 @@ void AscendSession::BuildGraph(GraphId graph_id) {
void
AscendSession
::
CompileChildGraph
(
const
KernelGraphPtr
&
child_graph
)
{
MS_EXCEPTION_IF_NULL
(
child_graph
);
MS_LOG
(
INFO
)
<<
"CompileChildGraph "
<<
child_graph
->
ToString
();
opt
::
AscendBackendIRFusionOptimization
(
child_graph
);
// select kernel build info
SelectKernel
(
*
child_graph
);
...
...
@@ -376,12 +392,14 @@ void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) {
auto
runtime_instance
=
device
::
KernelRuntimeManager
::
Instance
().
GetKernelRuntime
(
kAscendDevice
,
device_id_
);
MS_EXCEPTION_IF_NULL
(
runtime_instance
);
runtime_instance
->
AssignStaticMemoryInput
(
child_graph
.
get
());
runtime_instance
->
AssignStaticMemoryValueNode
(
child_graph
.
get
());
}
void
AscendSession
::
RunGraph
(
const
GraphId
&
graph_id
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
inputs
,
VectorRef
*
const
outputs
)
{
MS_LOG
(
INFO
)
<<
"start"
;
auto
kernel_graph
=
GetGraph
(
graph_id
);
DumpIR
(
"./run_graph.ir"
,
kernel_graph
);
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
// if none of child graph and no anf output exists
if
(
!
kernel_graph
->
executable
())
{
...
...
@@ -1378,10 +1396,10 @@ void AscendSession::SyncInitialTenosrToDevice() {
}
}
KernelGraphPtr
AscendSession
::
ConstructSplitedGraph
(
const
KernelGraphPtr
&
new_kernel_graph
,
const
std
::
vector
<
CNodePtr
>
&
list
)
{
std
::
vector
<
AnfNodePtr
>
AscendSession
::
ConstructSplitedGraph
(
const
KernelGraphPtr
&
new_kernel_graph
,
const
std
::
vector
<
CNodePtr
>
&
list
)
{
MS_EXCEPTION_IF_NULL
(
new_kernel_graph
);
MS_LOG
(
INFO
)
<<
"start
split
kernel graph:"
<<
new_kernel_graph
->
graph_id
();
MS_LOG
(
INFO
)
<<
"start
contruct splited
kernel graph:"
<<
new_kernel_graph
->
graph_id
();
// count the output of every anf node
std
::
set
<
AnfNodePtr
>
has_output_nodes
;
for
(
auto
&
anf_node
:
list
)
{
...
...
@@ -1390,21 +1408,23 @@ KernelGraphPtr AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_ke
}
}
MS_LOG
(
INFO
)
<<
"Construct input of kernel graph:"
<<
new_kernel_graph
->
graph_id
();
std
::
vector
<
AnfNodePtr
>
call_node_inputs
;
auto
graph_inputs
=
new_kernel_graph
->
MutableInputs
();
MS_EXCEPTION_IF_NULL
(
graph_inputs
);
// create new parameter from cnode
for
(
auto
&
anf_node
:
list
)
{
auto
cnode
=
anf_node
->
cast
<
CNodePtr
>
();
for
(
size_t
input_idx
=
1
;
input_idx
<
cnode
->
inputs
().
size
();
input_idx
++
)
{
auto
input
=
cnode
->
inputs
()[
input_idx
];
MS_EXCEPTION_IF_NULL
(
input
);
if
(
!
input
->
isa
<
CNode
>
())
{
if
(
input
->
isa
<
Parameter
>
())
{
graph_inputs
->
push_back
(
input
);
cnode
->
set_input
(
input_idx
,
input
);
continue
;
}
if
(
AnfAlgo
::
GetGraphId
(
input
.
get
())
!=
new_kernel_graph
->
graph_id
())
{
}
else
if
(
AnfAlgo
::
GetGraphId
(
input
.
get
())
!=
new_kernel_graph
->
graph_id
())
{
auto
new_parameter
=
CreateNewParameterFromCNode
(
input
,
true
,
new_kernel_graph
.
get
());
cnode
->
set_input
(
input_idx
,
new_parameter
);
new_kernel_graph
->
SetRealInput
(
new_parameter
,
input
);
}
call_node_inputs
.
push_back
(
input
);
}
}
MS_LOG
(
INFO
)
<<
"Construct output of kernel graph:"
<<
new_kernel_graph
->
graph_id
();
...
...
@@ -1424,7 +1444,7 @@ KernelGraphPtr AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_ke
new_kernel_graph
->
set_output
(
new_kernel_graph
->
NewCNode
(
make_tuple_inputs
));
}
MS_LOG
(
INFO
)
<<
"end"
;
return
new_kernel_graph
;
return
call_node_inputs
;
}
void
AscendSession
::
SplitGraphs
(
const
KernelGraphPtr
&
root_graph
)
{
...
...
@@ -1438,7 +1458,7 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL
(
graph
);
auto
apply_list
=
GetCNodes
(
TopoSort
(
graph
->
get_return
()));
// update the root graph child graph order
graph
->
UpdateChildGraphOrder
(
);
AscendControlParser
::
UpdateChildGraphOrder
(
NOT_NULL
(
graph
)
);
// get child list from current graph
std
::
vector
<
std
::
vector
<
CNodePtr
>>
child_graph_lists
=
GetChildList
(
*
graph
,
apply_list
);
auto
bind_new_call_to_new_graph
=
[
&
](
std
::
vector
<
CNodePtr
>
child_graph_list
)
->
AnfNodePtr
{
...
...
@@ -1457,7 +1477,8 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
for
(
auto
&
child_graph_node
:
child_graph_list
)
{
AnfAlgo
::
SetGraphId
(
child_graph
->
graph_id
(),
child_graph_node
.
get
());
}
ConstructSplitedGraph
(
child_graph
,
child_graph_list
);
auto
call_node_args
=
ConstructSplitedGraph
(
child_graph
,
child_graph_list
);
std
::
copy
(
call_node_args
.
begin
(),
call_node_args
.
end
(),
std
::
back_inserter
(
new_call_input
));
auto
new_call
=
graph
->
NewCNode
(
new_call_input
);
AnfAlgo
::
SetNodeAttr
(
"graph id"
,
MakeValue
(
graph
->
graph_id
()),
new_call
);
return
new_call
;
...
...
@@ -1466,26 +1487,59 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
std
::
list
<
AnfNodePtr
>
depend_input
=
{};
for
(
size_t
call_index
=
0
;
call_index
<
child_graph_lists
.
size
();
call_index
++
)
{
auto
call_node
=
bind_new_call_to_new_graph
(
child_graph_lists
[
call_index
]);
MS_EXCEPTION_IF_NULL
(
call_node
);
// if the call node is the last call of the true graph, there is no need to create a child graph after it
auto
child_graphs
=
AnfAlgo
::
GetCallNodeKernelGraph
(
call_node
->
cast
<
CNodePtr
>
());
depend_input
.
push_front
(
call_node
);
if
(
child_graphs
.
size
()
==
1
&&
child_graphs
[
0
]
==
graph
->
parent_graph
())
{
break
;
}
}
depend_input
.
push_front
(
graph
->
NewValueNode
(
NewValueNode
(
std
::
make_shared
<
Primitive
>
(
prim
::
kPrimDepend
->
name
()))));
auto
depend
=
graph
->
NewCNode
(
std
::
vector
<
AnfNodePtr
>
(
depend_input
.
begin
(),
depend_input
.
end
()));
auto
new_return_primitive
=
graph
->
NewValueNode
(
NewValueNode
(
std
::
make_shared
<
Primitive
>
(
prim
::
kPrimReturn
->
name
())));
graph
->
set_return
(
graph
->
NewCNode
({
new_return_primitive
,
depend
}));
AnfNodePtr
pre_call_node
=
nullptr
;
AnfNodePtr
cur_call_node
=
nullptr
;
auto
iter
=
depend_input
.
begin
();
for
(
++
iter
;
iter
!=
depend_input
.
end
();
++
iter
)
{
pre_call_node
=
cur_call_node
;
cur_call_node
=
*
iter
;
if
(
pre_call_node
!=
nullptr
&&
cur_call_node
!=
nullptr
)
{
AscendControlParser
::
InsertControlDependToGraph
(
NOT_NULL
(
graph
),
NOT_NULL
(
cur_call_node
),
NOT_NULL
(
pre_call_node
));
}
}
}
graph
->
UpdateChildGraphOrder
(
);
AscendControlParser
::
UpdateChildGraphOrder
(
NOT_NULL
(
graph
)
);
UpdateRealInput
(
graph
.
get
());
auto
graph_name
=
std
::
string
(
"./kernel-graph-"
).
append
(
std
::
to_string
(
graph
->
graph_id
()));
DumpIR
(
graph_name
,
graph
);
MS_LOG
(
INFO
)
<<
"split graph["
<<
graph
->
graph_id
()
<<
"] end"
;
// recurse to split child graph
for
(
auto
&
child_graph
:
graph
->
child_graph_order
())
{
SplitGraph
(
child_graph
);
if
(
child_graph
!=
graph
->
parent_graph
())
{
SplitGraph
(
child_graph
);
}
}
}
// Thin wrapper: hands the graph to AscendControlParser::LinkGraph.
void AscendSession::LinkChildGraphs(NotNull<KernelGraphPtr> graph) { AscendControlParser::LinkGraph(graph); }
// Thin wrapper: hands the graph to AscendControlParser::ExecutorValidate.
void AscendSession::RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph) {
  AscendControlParser::ExecutorValidate(graph);
}
// Compile `graph` itself, then recurse depth-first into every entry of its
// child_graph_order(). An entry that equals this graph's own parent_graph()
// is not descended into.
void AscendSession::RecurseCompileGraph(const KernelGraphPtr &graph) {
  CompileChildGraph(graph);
  for (auto sub_graph : graph->child_graph_order()) {
    if (sub_graph != graph->parent_graph()) {
      RecurseCompileGraph(sub_graph);
    }
  }
}
}
// namespace session
}
// namespace mindspore
mindspore/ccsrc/session/ascend_session.h
浏览文件 @
5aae0d91
...
...
@@ -104,10 +104,10 @@ class AscendSession : public SessionBasic {
void
SelectKernelGraphKernel
(
const
KernelGraph
&
graph
)
{}
void
ConvertPredictModel
(
const
KernelGraphPtr
graph
)
{}
void
HardwareOptimizeGraphs
(
const
KernelGraphPtr
graph
)
{}
void
RootGraphExecutorValidate
(
KernelGraph
*
graph
)
{}
void
RecurseUpdateAllChildGraohOrder
(
KernelGraph
*
root_graph
);
KernelGraphPtr
ConstructSplitedGraph
(
const
KernelGraphPtr
&
new_kernel_graph
,
const
std
::
vector
<
CNodePtr
>
&
list
);
void
ChildGraphCommunicationDecrease
(
std
::
vector
<
std
::
vector
<
AnfNodePtr
>>
*
anf_node_lists
);
void
RootGraphExecutorValidate
(
NotNull
<
KernelGraphPtr
>
graph
);
std
::
vector
<
AnfNodePtr
>
ConstructSplitedGraph
(
const
KernelGraphPtr
&
new_kernel_graph
,
const
std
::
vector
<
CNodePtr
>
&
list
);
void
RecurseCompileGraph
(
const
KernelGraphPtr
&
graph
);
// merge execution order list of child graphs
void
MergeGraphExecOrder
();
...
...
mindspore/ccsrc/session/kernel_graph.cc
浏览文件 @
5aae0d91
...
...
@@ -165,6 +165,21 @@ void KernelGraph::SetExecOrderByDefault() {
}
}
CheckLoop
();
// resort start label / end goto
std
::
vector
<
CNodePtr
>
re_order
;
if
(
start_label_
!=
nullptr
)
{
re_order
.
push_back
(
start_label_
);
}
for
(
auto
&
node
:
execution_order_
)
{
if
(
node
==
start_label_
||
node
==
end_goto_
)
{
continue
;
}
re_order
.
push_back
(
node
);
}
if
(
end_goto_
!=
nullptr
)
{
re_order
.
push_back
(
end_goto_
);
}
execution_order_
=
re_order
;
}
void
KernelGraph
::
CheckLoop
()
{
...
...
@@ -360,7 +375,8 @@ void KernelGraph::FrontBackendlMapAdd(const AnfNodePtr &front_anf, const AnfNode
void
KernelGraph
::
FrontBackendlMapUpdate
(
const
AnfNodePtr
&
old_backend_anf
,
const
AnfNodePtr
&
new_backend_anf
)
{
MS_EXCEPTION_IF_NULL
(
old_backend_anf
);
MS_EXCEPTION_IF_NULL
(
new_backend_anf
);
if
(
old_backend_anf
.
get
()
==
new_backend_anf
.
get
())
{
if
(
old_backend_anf
==
new_backend_anf
)
{
MS_LOG
(
INFO
)
<<
"old:"
<<
old_backend_anf
->
DebugString
()
<<
",new:"
<<
new_backend_anf
->
DebugString
();
MS_LOG
(
EXCEPTION
)
<<
"old can't be same with new"
;
}
if
(
backend_front_anf_map_
.
find
(
old_backend_anf
)
==
backend_front_anf_map_
.
end
())
{
...
...
@@ -569,32 +585,52 @@ void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf
MS_EXCEPTION_IF_NULL
(
new_anf_node
);
MS_EXCEPTION_IF_NULL
(
inputs_
);
auto
it
=
node_output_edges_
.
find
(
old_anf_node
);
if
(
it
=
=
node_output_edges_
.
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"Can't find anf node in node_output_edges map"
;
}
auto
&
outputs
=
it
->
second
;
for
(
auto
&
output_node
:
outputs
)
{
auto
output_cnode
=
output_node
.
first
->
cast
<
CNodePtr
>
(
);
MS_EXCEPTION_IF_NULL
(
output_cnode
);
auto
&
output_node_inputs
=
output_cnode
->
inputs
();
for
(
size_t
i
=
1
;
i
<
output_node_inputs
.
size
();
i
++
)
{
if
(
output_node_inputs
[
i
]
==
old_anf_node
)
{
output_cnode
->
set_input
(
i
,
new_anf_node
);
if
(
it
!
=
node_output_edges_
.
end
())
{
const
auto
&
outputs
=
it
->
second
;
for
(
auto
&
output_node
:
outputs
)
{
MS_EXCEPTION_IF_NULL
(
output_node
.
first
)
;
auto
output_cnode
=
output_node
.
first
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
output_cnode
);
const
auto
&
output_node_inputs
=
output_cnode
->
inputs
(
);
for
(
size_t
i
=
1
;
i
<
output_node_inputs
.
size
();
i
++
)
{
if
(
output_node_inputs
[
i
]
==
old_anf_node
)
{
output_cnode
->
set_input
(
i
,
new_anf_node
);
}
}
}
// update graph inputs
for
(
size_t
i
=
0
;
i
<
inputs_
->
size
();
i
++
)
{
if
((
*
inputs_
)[
i
]
==
old_anf_node
)
{
(
*
inputs_
)[
i
]
=
new_anf_node
;
break
;
// update graph inputs
for
(
size_t
i
=
0
;
i
<
inputs_
->
size
();
i
++
)
{
if
((
*
inputs_
)[
i
]
==
old_anf_node
)
{
MS_LOG
(
INFO
)
<<
"Replace input of graph:"
<<
graph_id_
<<
", old graph input: "
<<
old_anf_node
->
DebugString
()
<<
",new graph input:"
<<
new_anf_node
->
DebugString
();
(
*
inputs_
)[
i
]
=
new_anf_node
;
break
;
}
}
MS_LOG
(
INFO
)
<<
"Inputs of graph id:"
<<
graph_id
();
for
(
size_t
i
=
0
;
i
<
inputs
().
size
();
i
++
)
{
MS_LOG
(
INFO
)
<<
"["
<<
i
<<
"]:"
<<
inputs
()[
i
]
->
DebugString
();
}
}
// update front to backend map
FrontBackendlMapUpdate
(
old_anf_node
,
new_anf_node
);
// update output depend relations
node_output_edges_
[
new_anf_node
]
=
it
->
second
;
(
void
)
node_output_edges_
.
erase
(
old_anf_node
);
}
// update graph inputs in child graph
auto
it_real_inputs
=
real_inputs_
.
find
(
old_anf_node
);
if
(
it_real_inputs
!=
real_inputs_
.
end
())
{
// insert new parameter to map
auto
iter
=
real_inputs_
.
find
(
new_anf_node
);
if
(
iter
!=
real_inputs_
.
end
())
{
MS_LOG
(
WARNING
)
<<
new_anf_node
->
DebugString
()
<<
" already exist in real inputs, will be rewrited."
;
iter
->
second
=
it_real_inputs
->
second
;
}
else
{
real_inputs_
[
new_anf_node
]
=
it_real_inputs
->
second
;
}
// erase old parameter in map
real_inputs_
.
erase
(
old_anf_node
);
}
// update front to backend map
FrontBackendlMapUpdate
(
old_anf_node
,
new_anf_node
);
// update output depend relations
node_output_edges_
[
new_anf_node
]
=
it
->
second
;
(
void
)
node_output_edges_
.
erase
(
old_anf_node
);
}
void
KernelGraph
::
UpdateExecuteKernelStreamLabel
()
{
...
...
@@ -603,29 +639,6 @@ void KernelGraph::UpdateExecuteKernelStreamLabel() {
}
}
// Rebuild child_graph_order_ from the call nodes currently found in this graph.
//  1. Collect the call nodes via FindNodeByPrimitive.
//  2. Clear the parent link of every previously recorded child and drop the old list.
//  3. For each graph referenced by a call node, unless it is this graph's own
//     parent, set this graph as its parent and append it to child_graph_order_.
//  4. Log the resulting order.
void KernelGraph::UpdateChildGraphOrder() {
  MS_LOG(INFO) << "graph id:" << graph_id_;
  auto call_nodes = FindNodeByPrimitive(std::make_shared<Primitive>(prim::kPrimCall->name()));

  for (auto &stale_child : child_graph_order_) {
    stale_child->set_parent_graph(nullptr);
  }
  child_graph_order_.clear();

  for (auto &call_node : call_nodes) {
    MS_EXCEPTION_IF_NULL(call_node);
    auto callee_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast<CNodePtr>());
    for (const auto &child_graph : callee_graphs) {
      MS_EXCEPTION_IF_NULL(child_graph);
      if (child_graph == parent_graph()) {
        // A call back into the parent graph is not recorded as a child.
        continue;
      }
      child_graph->set_parent_graph(shared_from_this()->cast<std::shared_ptr<KernelGraph>>());
      child_graph_order_.push_back(child_graph);
    }
  }

  for (size_t idx = 0; idx < child_graph_order_.size(); idx++) {
    MS_LOG(INFO) << "child graph[" << idx << "][id:" << child_graph_order_[idx]->graph_id() << "]";
  }
}
std
::
vector
<
std
::
shared_ptr
<
KernelGraph
>>
KernelGraph
::
GetLeafGraphOrder
()
{
std
::
vector
<
std
::
shared_ptr
<
KernelGraph
>>
leaf_graph_order
;
if
(
IsLeafGraph
())
{
...
...
@@ -643,9 +656,8 @@ std::vector<std::shared_ptr<KernelGraph>> KernelGraph::GetLeafGraphOrder() {
// A graph is a leaf when it has no recorded child graphs.
bool KernelGraph::IsLeafGraph() const { return child_graph_order_.empty(); }
std
::
vector
<
CNodePtr
>
KernelGraph
::
FindNodeByPrimitive
(
const
PrimitivePtr
&
primitive
)
const
{
auto
anf_list
=
TopoSort
(
get_return
());
std
::
vector
<
CNodePtr
>
result
;
for
(
const
auto
&
anf
:
anf_list
)
{
for
(
const
auto
&
anf
:
execution_order_
)
{
if
(
AnfAlgo
::
CheckPrimitiveType
(
anf
,
primitive
)
&&
AnfAlgo
::
GetGraphId
(
anf
.
get
())
==
graph_id_
)
{
result
.
push_back
(
anf
->
cast
<
CNodePtr
>
());
}
...
...
@@ -653,14 +665,6 @@ std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primi
return
result
;
}
// Return a copy of the real-input set recorded for `parameter`, or an empty
// set when nothing has been recorded for it.
std::set<AnfNodePtr> KernelGraph::GetRealInput(const AnfNodePtr &parameter) {
  MS_EXCEPTION_IF_NULL(parameter);
  const auto entry = real_inputs_.find(parameter);
  if (entry != real_inputs_.end()) {
    return entry->second;
  }
  return {};
}
void
KernelGraph
::
SetRealInput
(
const
AnfNodePtr
&
parameter
,
const
AnfNodePtr
&
arg
)
{
MS_EXCEPTION_IF_NULL
(
parameter
);
MS_EXCEPTION_IF_NULL
(
arg
);
...
...
@@ -674,37 +678,41 @@ void KernelGraph::SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &ar
(
void
)
args
.
insert
(
arg
);
}
// Return a copy of the real-input set recorded for `parameter`.
// A parameter with no entry in real_inputs_ is reported through
// MS_LOG(EXCEPTION); that macro is expected not to return (the original has
// no return statement after it) — NOTE(review): confirm.
std::set<AnfNodePtr> KernelGraph::GetRealInput(const AnfNodePtr &parameter) {
  MS_EXCEPTION_IF_NULL(parameter);
  const auto entry = real_inputs_.find(parameter);
  if (entry == real_inputs_.end()) {
    MS_LOG(EXCEPTION) << parameter->DebugString() << " not found.";
  }
  return entry->second;
}
void
KernelGraph
::
UpdateCallRealInput
()
{
MS_LOG
(
INFO
)
<<
"Update graph id: "
<<
graph_id_
;
for
(
auto
&
it
:
real_inputs_
)
{
auto
&
parameter
=
it
.
first
;
MS_EXCEPTION_IF_NULL
(
parameter
);
auto
&
real_inputs
=
it
.
second
;
std
::
set
<
AnfNodePtr
>
new_real_inputs
;
std
::
vector
<
AnfNodePtr
>
new_real_inputs
;
std
::
set
<
AnfNodePtr
>
erase_real_inputs
;
for
(
auto
&
real_input
:
real_inputs
)
{
// if real input is a call node ,find the child graph output act as the new real input
auto
item_with_index
=
AnfAlgo
::
VisitKernelWithReturnType
(
real_input
,
0
);
MS_EXCEPTION_IF_NULL
(
item_with_index
.
first
);
if
(
AnfAlgo
::
CheckPrimitiveType
(
item_with_index
.
first
,
prim
::
kPrimCall
))
{
MS_LOG
(
INFO
)
<<
"paramter: "
<<
parameter
->
DebugString
()
<<
" erase real input:"
<<
item_with_index
.
first
->
DebugString
();
(
void
)
erase_real_inputs
.
insert
(
item_with_index
.
first
);
auto
call_node_outputs
=
GetCallRealOutputs
(
item_with_index
.
first
);
for
(
auto
&
call_node_output
:
call_node_outputs
)
{
MS_EXCEPTION_IF_NULL
(
call_node_output
);
MS_LOG
(
INFO
)
<<
"paramter: "
<<
parameter
->
DebugString
()
<<
" insert real input:"
<<
call_node_output
->
DebugString
();
(
void
)
new_real_inputs
.
insert
(
call_node_output
);
}
new_real_inputs
=
GetCallRealOutputs
(
item_with_index
.
first
);
continue
;
}
for
(
auto
&
erase_node
:
erase_real_inputs
)
{
(
void
)
real_inputs
.
erase
(
erase_node
);
}
for
(
auto
&
new_real_input
:
new_real_inputs
)
{
(
void
)
real_inputs
.
insert
(
new_real_input
);
}
}
for
(
auto
&
erase_node
:
erase_real_inputs
)
{
MS_LOG
(
INFO
)
<<
"paramter: "
<<
parameter
->
DebugString
()
<<
" erase real input:"
<<
erase_node
->
DebugString
();
(
void
)
real_inputs
.
erase
(
erase_node
);
}
for
(
auto
&
new_real_input
:
new_real_inputs
)
{
MS_LOG
(
INFO
)
<<
"paramter: "
<<
parameter
->
DebugString
()
<<
" insert real input:"
<<
new_real_input
->
DebugString
();
(
void
)
real_inputs
.
insert
(
new_real_input
);
}
}
}
...
...
mindspore/ccsrc/session/kernel_graph.h
浏览文件 @
5aae0d91
...
...
@@ -103,10 +103,9 @@ class KernelGraph : public FuncGraph {
void
UpdateExecuteKernelStreamLabel
();
// calculate the leaf graph order of root graph
std
::
vector
<
std
::
shared_ptr
<
KernelGraph
>>
GetLeafGraphOrder
();
// update the child graph order of graph
void
UpdateChildGraphOrder
();
// get the child graph of current graph
// Returns a copy of the recorded child graph list.
std::vector<std::shared_ptr<KernelGraph>> child_graph_order() const { return child_graph_order_; }
// the child graph of current graph
// Read-only view of the recorded child graph list (no copy is made).
const std::vector<std::shared_ptr<KernelGraph>> &child_graph_order() const { return child_graph_order_; }
// Replaces the recorded child graph list with a copy of `order`.
void set_child_graph_order(const std::vector<std::shared_ptr<KernelGraph>> &order) { child_graph_order_ = order; }
// check whether the current graph is a leaf graph
bool
IsLeafGraph
()
const
;
...
...
@@ -123,6 +122,7 @@ class KernelGraph : public FuncGraph {
// find anf node in graph
std
::
vector
<
CNodePtr
>
FindNodeByPrimitive
(
const
PrimitivePtr
&
primitive
)
const
;
// get real inputs
// Read-only view of the parameter -> real-input-set map.
const std::map<AnfNodePtr, std::set<AnfNodePtr>> &real_inputs() const { return real_inputs_; }
std
::
set
<
AnfNodePtr
>
GetRealInput
(
const
AnfNodePtr
&
parameter
);
void
SetRealInput
(
const
AnfNodePtr
&
parameter
,
const
AnfNodePtr
&
arg
);
// used to dump ir
...
...
@@ -132,6 +132,8 @@ class KernelGraph : public FuncGraph {
// Stores the node kept in start_label_.
void set_start_label(const CNodePtr &start_label) { start_label_ = start_label; }
// Returns the node kept in start_label_ (may be null if never set).
CNodePtr get_start_label() { return start_label_; }
// Stores the node kept in end_goto_.
void set_end_goto(const CNodePtr &end_goto) { end_goto_ = end_goto; }
// Returns the node kept in end_goto_ (may be null if never set).
CNodePtr get_end_goto() { return end_goto_; }
private:
// remove value node from graph
...
...
@@ -185,6 +187,7 @@ class KernelGraph : public FuncGraph {
std
::
map
<
AnfNodePtr
,
std
::
set
<
AnfNodePtr
>>
real_inputs_
;
CNodePtr
start_label_
;
CNodePtr
end_goto_
;
};
}
// namespace session
using
KernelGraphPtr
=
std
::
shared_ptr
<
session
::
KernelGraph
>
;
...
...
mindspore/ccsrc/session/session_basic.cc
浏览文件 @
5aae0d91
...
...
@@ -147,6 +147,7 @@ BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
MS_LOG
(
INFO
)
<<
"create tensor for output["
<<
anf
->
DebugString
()
<<
"]"
;
auto
item_with_index
=
AnfAlgo
::
VisitKernelWithReturnType
(
anf
,
0
);
MS_EXCEPTION_IF_NULL
(
item_with_index
.
first
);
MS_LOG
(
INFO
)
<<
"create tensor for output after visit:"
<<
item_with_index
.
first
->
DebugString
();
// special handle for maketuple
if
(
AnfAlgo
::
CheckPrimitiveType
(
item_with_index
.
first
,
prim
::
kPrimMakeTuple
))
{
auto
cnode
=
item_with_index
.
first
->
cast
<
CNodePtr
>
();
...
...
@@ -479,31 +480,12 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
}
for
(
size_t
input_idx
=
1
;
input_idx
<
cnode
->
inputs
().
size
();
input_idx
++
)
{
auto
anf
=
cnode
->
input
s
()[
input_idx
]
;
auto
anf
=
cnode
->
input
(
input_idx
)
;
MS_EXCEPTION_IF_NULL
(
anf
);
// anf has been created before
if
(
graph
->
GetBackendAnfByFrontAnf
(
anf
)
!=
nullptr
)
{
cnode_inputs
.
emplace_back
(
graph
->
GetBackendAnfByFrontAnf
(
anf
));
continue
;
}
else
if
(
anf
->
isa
<
ValueNode
>
())
{
if
(
!
IsValueNode
<
FuncGraph
>
(
anf
))
{
// if input is a common value node,
auto
new_value_node
=
CreateNewValueNode
(
anf
,
graph
);
if
(
new_value_node
!=
nullptr
)
{
cnode_inputs
.
emplace_back
(
new_value_node
);
}
}
else
{
// if input is a ValueNode<FuncGraph>
auto
new_value_node
=
CreateValueNodeKernelGraph
(
anf
,
graph
);
if
(
new_value_node
!=
nullptr
)
{
cnode_inputs
.
emplace_back
(
new_value_node
);
}
}
continue
;
}
else
if
(
anf
->
isa
<
Parameter
>
())
{
auto
new_parameter
=
CreateNewParameter
(
anf
,
graph
);
cnode_inputs
.
push_back
(
new_parameter
);
continue
;
}
MS_LOG
(
EXCEPTION
)
<<
"Unexpected input["
<<
anf
->
DebugString
()
<<
"]"
;
}
...
...
@@ -613,32 +595,22 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
for
(
const
auto
&
node
:
node_list
)
{
MS_EXCEPTION_IF_NULL
(
node
);
MS_LOG
(
DEBUG
)
<<
"Start create new cnode, node = "
<<
node
->
DebugString
();
if
(
!
node
->
isa
<
CNode
>
())
{
MS_LOG
(
DEBUG
)
<<
"Node "
<<
node
->
DebugString
()
<<
" is not CNode"
;
if
(
node
->
isa
<
Parameter
>
())
{
(
void
)
CreateNewParameter
(
node
,
graph
.
get
());
continue
;
}
else
if
(
node
->
isa
<
ValueNode
>
())
{
if
(
!
IsValueNode
<
FuncGraph
>
(
node
))
{
// if input is a common value node,
(
void
)
CreateNewValueNode
(
node
,
graph
.
get
());
}
else
{
// if input is a ValueNode<FuncGraph>
auto
child_graph
=
ConstructKernelGraph
(
AnfAlgo
::
GetValueNodeFuncGraph
(
node
));
auto
new_value_node
=
CreateValueNodeKernelGraph
(
node
,
graph
.
get
());
}
continue
;
}
else
{
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
// recurse control ops: call, partial
auto
attr_input
=
cnode
->
input
(
kAnfPrimitiveIndex
);
MS_EXCEPTION_IF_NULL
(
attr_input
);
if
(
IsValueNode
<
FuncGraph
>
(
attr_input
))
{
// recurse call subgraph
auto
sub_func_graph
=
AnfAlgo
::
GetValueNodeFuncGraph
(
attr_input
);
ConstructKernelGraph
(
sub_func_graph
);
}
else
if
(
IsValueNode
<
Primitive
>
(
attr_input
))
{
auto
prim
=
GetCNodePrimitive
(
node
);
MS_EXCEPTION_IF_NULL
(
prim
);
if
(
prim
->
name
()
==
kPartialOpName
)
{
// recurse partial subgraph
auto
func_graph_node
=
cnode
->
input
(
kAnfPartialFuncGraphIndex
);
MS_EXCEPTION_IF_NULL
(
func_graph_node
);
auto
sub_func_graph
=
AnfAlgo
::
GetValueNodeFuncGraph
(
func_graph_node
);
ConstructKernelGraph
(
sub_func_graph
);
}
}
// create a new cnode object
auto
new_cnode
=
CreateNewCNode
(
cnode
,
graph
.
get
());
MS_EXCEPTION_IF_NULL
(
new_cnode
);
...
...
@@ -650,7 +622,21 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
}
}
}
auto
graph_inputs
=
graph
->
MutableInputs
();
MS_EXCEPTION_IF_NULL
(
graph_inputs
);
graph_inputs
->
clear
();
for
(
auto
&
parameter
:
func_graph
->
parameters
())
{
MS_EXCEPTION_IF_NULL
(
parameter
);
auto
backend_parameter
=
graph
->
GetBackendAnfByFrontAnf
(
parameter
);
if
(
backend_parameter
==
nullptr
)
{
// for example "def f(x,y,z) {return x + y}", parameter z is unused
CreateNewParameterFromParameter
(
parameter
,
false
,
graph
.
get
());
MS_LOG
(
INFO
)
<<
"Can't find parameter:"
<<
parameter
->
DebugString
();
continue
;
}
MS_LOG
(
INFO
)
<<
"graph["
<<
graph
->
graph_id
()
<<
"],parameter:"
<<
parameter
->
DebugString
();
graph_inputs
->
push_back
(
backend_parameter
);
}
MS_EXCEPTION_IF_NULL
(
context_
);
FuncGraphManagerPtr
manager
=
context_
->
manager
();
if
(
manager
)
{
...
...
@@ -716,6 +702,11 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
)
const
{
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
MS_EXCEPTION_IF_NULL
(
outputs
);
if
(
!
kernel_graph
->
child_graph_order
().
empty
())
{
// use the last child graph output as the root graph output
UpdateOutputs
(
kernel_graph
->
child_graph_order
().
back
(),
outputs
,
input_tensors
);
return
;
}
auto
anf_outputs
=
kernel_graph
->
outputs
();
for
(
auto
&
item
:
anf_outputs
)
{
MS_LOG
(
INFO
)
<<
"update output["
<<
item
->
DebugString
()
<<
"]"
;
...
...
mindspore/ccsrc/vm/transform.cc
浏览文件 @
5aae0d91
...
...
@@ -487,8 +487,7 @@ void CompileGraph::AddExternal(const LinConvertResult &result) {
}
void
TraverseGraphMap
(
const
FuncGraphManagerPtr
&
manager_ptr
,
FuncGraphTransaction
*
const
tr
,
const
FuncGraphSet
&
fgs
,
const
FuncGraphManagerPtr
&
manager_ptr
,
FuncGraphTransaction
*
const
tr
,
const
FuncGraphSet
&
fgs
,
const
std
::
function
<
std
::
shared_ptr
<
FuncGraph
>
(
const
PrimitivePtr
,
const
AbstractFunctionPtr
)
>
&
get_prim_graph
)
{
MS_EXCEPTION_IF_NULL
(
manager_ptr
);
MS_EXCEPTION_IF_NULL
(
tr
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录