Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
5aae0d91
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5aae0d91
编写于
5月 26, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
5月 26, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1459 Insert assign nodes for linking sub graph
Merge pull request !1459 from zhoufeng/link-assign
上级
7f80d028
f868a285
变更
10
展开全部
显示空白变更内容
内联
并排
Showing
10 changed file
with
542 addition
and
264 deletion
+542
-264
mindspore/ccsrc/device/ascend/ascend_label_assign.cc
mindspore/ccsrc/device/ascend/ascend_label_assign.cc
+17
-2
mindspore/ccsrc/device/kernel_runtime.h
mindspore/ccsrc/device/kernel_runtime.h
+1
-1
mindspore/ccsrc/session/ascend_control_parser.cc
mindspore/ccsrc/session/ascend_control_parser.cc
+281
-87
mindspore/ccsrc/session/ascend_control_parser.h
mindspore/ccsrc/session/ascend_control_parser.h
+24
-10
mindspore/ccsrc/session/ascend_session.cc
mindspore/ccsrc/session/ascend_session.cc
+93
-39
mindspore/ccsrc/session/ascend_session.h
mindspore/ccsrc/session/ascend_session.h
+4
-4
mindspore/ccsrc/session/kernel_graph.cc
mindspore/ccsrc/session/kernel_graph.cc
+80
-72
mindspore/ccsrc/session/kernel_graph.h
mindspore/ccsrc/session/kernel_graph.h
+7
-4
mindspore/ccsrc/session/session_basic.cc
mindspore/ccsrc/session/session_basic.cc
+34
-43
mindspore/ccsrc/vm/transform.cc
mindspore/ccsrc/vm/transform.cc
+1
-2
未找到文件。
mindspore/ccsrc/device/ascend/ascend_label_assign.cc
浏览文件 @
5aae0d91
...
@@ -28,6 +28,9 @@ namespace device {
...
@@ -28,6 +28,9 @@ namespace device {
namespace
ascend
{
namespace
ascend
{
static
void
UpdateLabelGoto
(
NotNull
<
CNodePtr
>
node
)
{
static
void
UpdateLabelGoto
(
NotNull
<
CNodePtr
>
node
)
{
if
(
AnfAlgo
::
HasNodeAttr
(
kAttrLabelIndex
,
node
))
{
return
;
}
if
(
node
->
size
()
<=
kLabelGotoLabelId
)
{
if
(
node
->
size
()
<=
kLabelGotoLabelId
)
{
MS_LOG
(
EXCEPTION
)
<<
"Node "
<<
node
->
DebugString
()
<<
" has invalid input size "
<<
node
->
size
();
MS_LOG
(
EXCEPTION
)
<<
"Node "
<<
node
->
DebugString
()
<<
" has invalid input size "
<<
node
->
size
();
}
}
...
@@ -42,6 +45,9 @@ static void UpdateLabelGoto(NotNull<CNodePtr> node) {
...
@@ -42,6 +45,9 @@ static void UpdateLabelGoto(NotNull<CNodePtr> node) {
}
}
static
void
UpdateLabelSwitch
(
NotNull
<
CNodePtr
>
node
)
{
static
void
UpdateLabelSwitch
(
NotNull
<
CNodePtr
>
node
)
{
if
(
AnfAlgo
::
HasNodeAttr
(
kAttrLabelIndex
,
node
))
{
return
;
}
if
(
node
->
size
()
<=
kLabelGotoLabelId
)
{
if
(
node
->
size
()
<=
kLabelGotoLabelId
)
{
MS_LOG
(
EXCEPTION
)
<<
"Node "
<<
node
->
DebugString
()
<<
" has invalid input size "
<<
node
->
size
();
MS_LOG
(
EXCEPTION
)
<<
"Node "
<<
node
->
DebugString
()
<<
" has invalid input size "
<<
node
->
size
();
}
}
...
@@ -69,9 +75,12 @@ static void AssignLabelForLabelSet(NotNull<std::shared_ptr<session::KernelGraph>
...
@@ -69,9 +75,12 @@ static void AssignLabelForLabelSet(NotNull<std::shared_ptr<session::KernelGraph>
if
(
memo
->
find
(
graph
.
get
())
!=
memo
->
end
())
{
if
(
memo
->
find
(
graph
.
get
())
!=
memo
->
end
())
{
return
;
return
;
}
}
memo
->
insert
(
graph
.
get
());
MS_LOG
(
INFO
)
<<
"Assign label for "
<<
graph
->
ToString
();
MS_LOG
(
INFO
)
<<
"Assign label for "
<<
graph
->
ToString
();
auto
nodes
=
TopoSort
(
graph
->
get_return
());
graph
->
SetExecOrderByDefault
();
auto
nodes
=
graph
->
execution_order
();
for
(
auto
&
node
:
nodes
)
{
for
(
auto
&
node
:
nodes
)
{
if
(
!
node
->
isa
<
CNode
>
())
{
if
(
!
node
->
isa
<
CNode
>
())
{
continue
;
continue
;
...
@@ -97,9 +106,15 @@ static void AssignLabelForGotoSwitch(NotNull<std::shared_ptr<session::KernelGrap
...
@@ -97,9 +106,15 @@ static void AssignLabelForGotoSwitch(NotNull<std::shared_ptr<session::KernelGrap
if
(
memo
->
find
(
graph
.
get
())
!=
memo
->
end
())
{
if
(
memo
->
find
(
graph
.
get
())
!=
memo
->
end
())
{
return
;
return
;
}
}
memo
->
insert
(
graph
.
get
());
MS_LOG
(
INFO
)
<<
"Process label goto/switch for "
<<
graph
->
ToString
();
MS_LOG
(
INFO
)
<<
"Process label goto/switch for "
<<
graph
->
ToString
();
auto
nodes
=
TopoSort
(
graph
->
get_return
());
graph
->
SetExecOrderByDefault
();
auto
nodes
=
graph
->
execution_order
();
auto
end_goto
=
graph
->
get_end_goto
();
if
(
end_goto
!=
nullptr
)
{
nodes
.
push_back
(
end_goto
);
}
for
(
auto
&
node
:
nodes
)
{
for
(
auto
&
node
:
nodes
)
{
if
(
!
node
->
isa
<
CNode
>
())
{
if
(
!
node
->
isa
<
CNode
>
())
{
continue
;
continue
;
...
...
mindspore/ccsrc/device/kernel_runtime.h
浏览文件 @
5aae0d91
...
@@ -53,6 +53,7 @@ class KernelRuntime {
...
@@ -53,6 +53,7 @@ class KernelRuntime {
virtual
bool
GenTask
(
const
session
::
KernelGraph
*
graph
);
virtual
bool
GenTask
(
const
session
::
KernelGraph
*
graph
);
bool
LaunchKernel
(
const
session
::
KernelGraph
*
graph
);
bool
LaunchKernel
(
const
session
::
KernelGraph
*
graph
);
virtual
void
AssignStaticMemoryInput
(
const
session
::
KernelGraph
*
graph
);
virtual
void
AssignStaticMemoryInput
(
const
session
::
KernelGraph
*
graph
);
virtual
void
AssignStaticMemoryValueNode
(
session
::
KernelGraph
*
graph
);
#ifdef ENABLE_DUMP_E2E
#ifdef ENABLE_DUMP_E2E
DumpConfPtr
GetDumpConf
();
DumpConfPtr
GetDumpConf
();
...
@@ -67,7 +68,6 @@ class KernelRuntime {
...
@@ -67,7 +68,6 @@ class KernelRuntime {
TypeId
type_id
)
=
0
;
TypeId
type_id
)
=
0
;
virtual
bool
SyncStream
()
=
0
;
virtual
bool
SyncStream
()
=
0
;
void
AssignStaticMemory
(
session
::
KernelGraph
*
graph
);
void
AssignStaticMemory
(
session
::
KernelGraph
*
graph
);
void
AssignStaticMemoryValueNode
(
session
::
KernelGraph
*
graph
);
void
AssignDynamicMemory
(
session
::
KernelGraph
*
graph
);
void
AssignDynamicMemory
(
session
::
KernelGraph
*
graph
);
void
ReuseAssignDynamicMemory
(
session
::
KernelGraph
*
graph
);
void
ReuseAssignDynamicMemory
(
session
::
KernelGraph
*
graph
);
void
AssignNodeOutputMem
(
int
flag
,
const
AnfNodePtr
&
node
,
int
index
);
void
AssignNodeOutputMem
(
int
flag
,
const
AnfNodePtr
&
node
,
int
index
);
...
...
mindspore/ccsrc/session/ascend_control_parser.cc
浏览文件 @
5aae0d91
此差异已折叠。
点击以展开。
mindspore/ccsrc/session/ascend_control_parser.h
浏览文件 @
5aae0d91
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#define MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H
#define MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H
#include <set>
#include <set>
#include <map>
#include <vector>
#include <vector>
#include <tuple>
#include <tuple>
#include "session/kernel_graph.h"
#include "session/kernel_graph.h"
...
@@ -28,31 +29,44 @@ namespace session {
...
@@ -28,31 +29,44 @@ namespace session {
class
AscendControlParser
{
class
AscendControlParser
{
public:
public:
static
void
ChildGraphDataAssign
(
const
std
::
map
<
uint32_t
,
KernelGraphPtr
>
&
graph_id_map
);
static
void
LinkGraph
(
NotNull
<
KernelGraphPtr
>
kg
);
static
void
LinkGraph
(
NotNull
<
KernelGraphPtr
>
kg
);
static
void
InsertDependToGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
attch_node
);
static
void
InsertDependToGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
attch_node
);
static
void
InsertControlDependToGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
first_node
,
static
void
InsertControlDependToGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
first_node
,
NotNull
<
AnfNodePtr
>
second_node
);
NotNull
<
AnfNodePtr
>
second_node
);
static
void
ExecutorValidate
(
NotNull
<
KernelGraphPtr
>
root_graph
);
static
void
UpdateChildGraphOrder
(
NotNull
<
KernelGraphPtr
>
kg
);
private:
private:
static
NotNull
<
CNodePtr
>
ProcessKernelGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
const
CNodePtr
&
last_node
,
static
NotNull
<
CNodePtr
>
ProcessKernelGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
const
CNodePtr
&
last_node
,
const
CNodePtr
&
last_label
,
const
VectorRef
&
args
,
const
CNodePtr
&
last_label
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
);
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
);
static
void
RecurseCall
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
CNodePtr
>
cur_node
,
const
CNodePtr
&
next_node
,
static
void
RecurseCall
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
CNodePtr
>
cur_node
,
const
CNodePtr
&
next_node
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
);
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
);
static
void
RecurseSwitch
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
CNodePtr
>
cur_node
,
static
void
RecurseSwitch
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
CNodePtr
>
cur_node
,
const
CNodePtr
&
next_node
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
);
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
);
static
void
RecurseSwitchLayer
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
CNodePtr
>
cur_node
,
static
void
RecurseSwitchLayer
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
CNodePtr
>
cur_node
,
const
CNodePtr
&
next_node
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
);
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
);
static
std
::
vector
<
CNodePtr
>
GetCNodes
(
const
std
::
vector
<
AnfNodePtr
>
&
in
);
static
void
LinkParentGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
const
CNodePtr
&
from_graph_call_node
,
static
void
LinkParentGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
const
CNodePtr
&
from_graph_call_node
,
const
CNodePtr
&
last_label
,
const
VectorRef
&
args
);
const
CNodePtr
&
last_label
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
);
static
void
SetSubGraphInput
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
CNodePtr
>
from_graph_call_node
,
static
std
::
tuple
<
CNodePtr
,
KernelGraphPtr
>
ParsePartial
(
NotNull
<
AnfNodePtr
>
node
);
const
VectorRef
&
args
);
static
std
::
tuple
<
CNodePtr
,
KernelGraphPtr
,
VectorRef
>
ParsePartial
(
NotNull
<
AnfNodePtr
>
node
);
static
void
LinkArgsToParam
(
NotNull
<
KernelGraphPtr
>
to_graph
,
NotNull
<
KernelGraphPtr
>
target_graph
,
NotNull
<
AnfNodePtr
>
arg
,
NotNull
<
AnfNodePtr
>
param
);
static
NotNull
<
AnfNodePtr
>
GetRealInput
(
NotNull
<
KernelGraphPtr
>
from_graph
,
NotNull
<
KernelGraphPtr
>
to_graph
,
NotNull
<
AnfNodePtr
>
param
);
static
void
InsertAssignToGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
from
,
NotNull
<
AnfNodePtr
>
to
);
static
void
InsertAssignToGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
from
,
NotNull
<
AnfNodePtr
>
to
);
static
size_t
SetChildGraphInput
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
node
,
size_t
input_index
);
static
CNodePtr
GetNextRealKernel
(
std
::
vector
<
CNodePtr
>
list
,
size_t
start
);
// root graph order
static
std
::
tuple
<
std
::
map
<
uint32_t
,
CNodePtr
>
,
std
::
map
<
CNodePtr
,
std
::
vector
<
uint32_t
>>>
GetLabelNode
(
const
std
::
vector
<
CNodePtr
>
&
nodes
);
static
bool
CheckLabelIndex
(
uint32_t
order_index
,
uint32_t
label_index
,
const
CNodePtr
&
cnode
,
NotNull
<
KernelGraphPtr
>
graph
);
static
std
::
vector
<
CNodePtr
>
RecurseGraph
(
const
CNodePtr
&
cur_label_goto
,
const
CNodePtr
&
end_label_goto
,
NotNull
<
KernelGraphPtr
>
graph
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
);
static
constexpr
size_t
kCNodePrim
=
0
;
static
constexpr
size_t
kCNodePrim
=
0
;
static
constexpr
size_t
kCNodeCallArg
=
1
;
static
constexpr
size_t
kCNodeCallArg
=
1
;
...
...
mindspore/ccsrc/session/ascend_session.cc
浏览文件 @
5aae0d91
...
@@ -177,10 +177,6 @@ std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, co
...
@@ -177,10 +177,6 @@ std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, co
for
(
size_t
i
=
0
;
i
<
cnodes
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
cnodes
.
size
();
i
++
)
{
if
(
AnfAlgo
::
CheckPrimitiveType
(
cnodes
[
i
],
prim
::
kPrimCall
)
&&
!
AnfAlgo
::
IsSwitchCall
(
cnodes
[
i
]))
{
if
(
AnfAlgo
::
CheckPrimitiveType
(
cnodes
[
i
],
prim
::
kPrimCall
)
&&
!
AnfAlgo
::
IsSwitchCall
(
cnodes
[
i
]))
{
auto
call_kernel_graph
=
AnfAlgo
::
GetCallNodeKernelGraph
(
cnodes
[
i
]);
auto
call_kernel_graph
=
AnfAlgo
::
GetCallNodeKernelGraph
(
cnodes
[
i
]);
// if graph is the true branch of while,no need split graph
if
(
call_kernel_graph
.
size
()
==
1
&&
call_kernel_graph
[
0
]
==
cur_graph
.
parent_graph
())
{
continue
;
}
auto
prev_call_list
=
std
::
vector
<
CNodePtr
>
(
cnodes
.
begin
()
+
after_call_index
,
cnodes
.
begin
()
+
i
);
auto
prev_call_list
=
std
::
vector
<
CNodePtr
>
(
cnodes
.
begin
()
+
after_call_index
,
cnodes
.
begin
()
+
i
);
auto
call_list
=
std
::
vector
<
CNodePtr
>
(
1
,
cnodes
[
i
]);
auto
call_list
=
std
::
vector
<
CNodePtr
>
(
1
,
cnodes
[
i
]);
after_call_index
=
i
+
1
;
after_call_index
=
i
+
1
;
...
@@ -195,9 +191,9 @@ std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, co
...
@@ -195,9 +191,9 @@ std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, co
// if a call has kernel input, it's a child graph split from ME, so these kernel input should be set into real input of
// if a call has kernel input, it's a child graph split from ME, so these kernel input should be set into real input of
// graph.For example, call input = (prim,graph,kernel1,kernel2),then real_input = [kernel1,kernel2]
// graph.For example, call input = (prim,graph,kernel1,kernel2),then real_input = [kernel1,kernel2]
void
UpdateRealInput
(
KernelGraph
*
graph
)
{
static
void
UpdateRealInput
(
KernelGraph
*
graph
)
{
auto
call_nodes
=
graph
->
FindNodeByPrimitive
(
prim
::
kPrimCall
);
auto
call_nodes
=
graph
->
FindNodeByPrimitive
(
prim
::
kPrimCall
);
auto
bind_call_
partial
_with_parameter
=
[
&
](
const
std
::
vector
<
AnfNodePtr
>
&
parameters
,
auto
bind_call_
arg
_with_parameter
=
[
&
](
const
std
::
vector
<
AnfNodePtr
>
&
parameters
,
const
std
::
vector
<
AnfNodePtr
>
&
args
,
KernelGraph
*
child_graph
)
->
void
{
const
std
::
vector
<
AnfNodePtr
>
&
args
,
KernelGraph
*
child_graph
)
->
void
{
MS_EXCEPTION_IF_NULL
(
child_graph
);
MS_EXCEPTION_IF_NULL
(
child_graph
);
MS_LOG
(
INFO
)
<<
"start bind parameter of child graph:"
<<
child_graph
->
graph_id
();
MS_LOG
(
INFO
)
<<
"start bind parameter of child graph:"
<<
child_graph
->
graph_id
();
...
@@ -208,8 +204,21 @@ void UpdateRealInput(KernelGraph *graph) {
...
@@ -208,8 +204,21 @@ void UpdateRealInput(KernelGraph *graph) {
MS_LOG
(
EXCEPTION
)
<<
"graph:"
<<
child_graph
->
graph_id
()
<<
" parameters size:"
<<
parameters
.
size
()
MS_LOG
(
EXCEPTION
)
<<
"graph:"
<<
child_graph
->
graph_id
()
<<
" parameters size:"
<<
parameters
.
size
()
<<
" and args size:"
<<
args
.
size
()
<<
" not equal!"
;
<<
" and args size:"
<<
args
.
size
()
<<
" not equal!"
;
}
}
child_graph
->
SetExecOrderByDefault
();
for
(
size_t
i
=
0
;
i
<
parameters
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
parameters
.
size
();
i
++
)
{
MS_LOG
(
INFO
)
<<
"bind paramreter:"
<<
parameters
[
i
]
->
DebugString
()
<<
" ,arg:"
<<
args
[
i
]
->
DebugString
();
if
(
args
[
i
]
==
parameters
[
i
])
{
child_graph
->
SetRealInput
(
parameters
[
i
],
args
[
i
]);
MS_LOG
(
INFO
)
<<
"Parameter and arg are same"
;
continue
;
}
// if arg is a parameter ,then reuse this parameter
if
(
args
[
i
]
->
isa
<
Parameter
>
())
{
MS_LOG
(
INFO
)
<<
"Parameter:"
<<
parameters
[
i
]
->
DebugString
()
<<
" of graph:"
<<
child_graph
->
graph_id
()
<<
" reuse parameter:"
<<
args
[
i
]
->
DebugString
()
<<
" of graph:"
<<
AnfAlgo
::
GetGraphId
(
args
[
i
].
get
());
child_graph
->
ReplaceNode
(
parameters
[
i
],
args
[
i
]);
continue
;
}
child_graph
->
SetRealInput
(
parameters
[
i
],
args
[
i
]);
child_graph
->
SetRealInput
(
parameters
[
i
],
args
[
i
]);
}
}
};
};
...
@@ -218,9 +227,10 @@ void UpdateRealInput(KernelGraph *graph) {
...
@@ -218,9 +227,10 @@ void UpdateRealInput(KernelGraph *graph) {
auto
child_graphs
=
AnfAlgo
::
GetCallNodeKernelGraph
(
call_node
);
auto
child_graphs
=
AnfAlgo
::
GetCallNodeKernelGraph
(
call_node
);
if
(
child_graphs
.
size
()
==
1
)
{
if
(
child_graphs
.
size
()
==
1
)
{
MS_EXCEPTION_IF_NULL
(
child_graphs
[
0
]);
MS_EXCEPTION_IF_NULL
(
child_graphs
[
0
]);
bind_call_partial_with_parameter
(
std
::
vector
<
AnfNodePtr
>
real_args
=
child_graphs
[
0
]
->
inputs
(),
std
::
vector
<
AnfNodePtr
>
(
call_node
->
inputs
().
begin
()
+
2
,
call_node
->
inputs
().
end
()),
std
::
vector
<
AnfNodePtr
>
(
call_node
->
inputs
().
begin
()
+
2
,
call_node
->
inputs
().
end
());
child_graphs
[
0
].
get
());
std
::
vector
<
AnfNodePtr
>
child_inputs
=
child_graphs
[
0
]
->
inputs
();
bind_call_arg_with_parameter
(
child_inputs
,
real_args
,
child_graphs
[
0
].
get
());
call_node
->
set_inputs
(
std
::
vector
<
AnfNodePtr
>
(
call_node
->
inputs
().
begin
(),
call_node
->
inputs
().
begin
()
+
2
));
call_node
->
set_inputs
(
std
::
vector
<
AnfNodePtr
>
(
call_node
->
inputs
().
begin
(),
call_node
->
inputs
().
begin
()
+
2
));
}
else
if
(
child_graphs
.
size
()
==
2
)
{
}
else
if
(
child_graphs
.
size
()
==
2
)
{
auto
get_partial_args
=
[
&
](
size_t
input_index
)
->
std
::
vector
<
AnfNodePtr
>
{
auto
get_partial_args
=
[
&
](
size_t
input_index
)
->
std
::
vector
<
AnfNodePtr
>
{
...
@@ -237,8 +247,8 @@ void UpdateRealInput(KernelGraph *graph) {
...
@@ -237,8 +247,8 @@ void UpdateRealInput(KernelGraph *graph) {
std
::
vector
<
AnfNodePtr
>
(
partial_cnode
->
inputs
().
begin
(),
partial_cnode
->
inputs
().
begin
()
+
2
));
std
::
vector
<
AnfNodePtr
>
(
partial_cnode
->
inputs
().
begin
(),
partial_cnode
->
inputs
().
begin
()
+
2
));
return
ret
;
return
ret
;
};
};
bind_call_
partial
_with_parameter
(
child_graphs
[
0
]
->
inputs
(),
get_partial_args
(
2
),
child_graphs
[
0
].
get
());
bind_call_
arg
_with_parameter
(
child_graphs
[
0
]
->
inputs
(),
get_partial_args
(
2
),
child_graphs
[
0
].
get
());
bind_call_
partial
_with_parameter
(
child_graphs
[
1
]
->
inputs
(),
get_partial_args
(
3
),
child_graphs
[
1
].
get
());
bind_call_
arg
_with_parameter
(
child_graphs
[
1
]
->
inputs
(),
get_partial_args
(
3
),
child_graphs
[
1
].
get
());
}
}
}
}
}
}
...
@@ -248,6 +258,11 @@ void RecurseToUpdateCallRealInput(KernelGraph *graph) {
...
@@ -248,6 +258,11 @@ void RecurseToUpdateCallRealInput(KernelGraph *graph) {
MS_LOG
(
INFO
)
<<
"start graph id:"
<<
graph
->
graph_id
();
MS_LOG
(
INFO
)
<<
"start graph id:"
<<
graph
->
graph_id
();
graph
->
UpdateCallRealInput
();
graph
->
UpdateCallRealInput
();
for
(
auto
&
child_graph
:
graph
->
child_graph_order
())
{
for
(
auto
&
child_graph
:
graph
->
child_graph_order
())
{
if
(
child_graph
==
graph
->
parent_graph
())
{
MS_LOG
(
INFO
)
<<
"Child graph:"
<<
child_graph
->
graph_id
()
<<
",parent graph:"
<<
graph
->
parent_graph
()
->
graph_id
();
continue
;
}
RecurseToUpdateCallRealInput
(
child_graph
.
get
());
RecurseToUpdateCallRealInput
(
child_graph
.
get
());
}
}
}
}
...
@@ -265,31 +280,31 @@ GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrL
...
@@ -265,31 +280,31 @@ GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrL
GraphId
AscendSession
::
CompileGraph
(
NotNull
<
FuncGraphPtr
>
func_graph
)
{
GraphId
AscendSession
::
CompileGraph
(
NotNull
<
FuncGraphPtr
>
func_graph
)
{
MS_LOG
(
INFO
)
<<
"start"
;
MS_LOG
(
INFO
)
<<
"start"
;
auto
graph
=
ConstructKernelGraph
(
func_graph
);
auto
graph
=
ConstructKernelGraph
(
func_graph
);
MS_LOG
(
INFO
)
<<
"graph input size:"
<<
graph
->
inputs
().
size
();
// split switch
// split switch
SplitGraphs
(
graph
);
SplitGraphs
(
graph
);
MS_LOG
(
INFO
)
<<
"graph input size:"
<<
graph
->
inputs
().
size
();
// insert goto labels and label_sets
// insert goto labels and label_sets
LinkChildGraphs
(
NOT_NULL
(
graph
));
LinkChildGraphs
(
NOT_NULL
(
graph
));
MS_LOG
(
INFO
)
<<
"graph input size:"
<<
graph
->
inputs
().
size
();
// resource initialize
// resource initialize
InitRuntimeResource
();
InitRuntimeResource
();
// assign label
// assign label
AssignLabel
(
NOT_NULL
(
graph
));
AssignLabel
(
NOT_NULL
(
graph
));
if
(
!
graph
->
executable
())
{
MS_LOG
(
INFO
)
<<
"graph input size:"
<<
graph
->
inputs
().
size
();
return
graph
->
graph_id
();
// recurse compile child graph
}
RecurseCompileGraph
(
graph
);
for
(
auto
iter
:
graphs_
)
{
MS_LOG
(
INFO
)
<<
"graph input size:"
<<
graph
->
inputs
().
size
();
if
(
iter
.
second
==
graph
)
{
// root graph valiate,include genearte execute order and so on
MS_LOG
(
INFO
)
<<
"Entry graph "
<<
graph
->
ToString
()
<<
" graph id "
<<
graph
->
graph_id
();
RootGraphExecutorValidate
(
NOT_NULL
(
graph
));
final_graph_id_
=
graph
->
graph_id
();
MS_LOG
(
INFO
)
<<
"graph input size:"
<<
graph
->
inputs
().
size
();
}
MS_LOG
(
INFO
)
<<
"CompileChildGraph "
<<
iter
.
second
->
ToString
();
CompileChildGraph
(
iter
.
second
);
}
// adjust kernel
// adjust kernel
AdjustKernel
(
graph
);
AdjustKernel
(
graph
);
// root graph valiate,include genearte execute order and so on
MS_LOG
(
INFO
)
<<
"graph input size:"
<<
graph
->
inputs
().
size
();
RootGraphExecutorValidate
(
graph
.
get
());
// assign stream
// assign stream
AssignStream
(
graph
);
AssignStream
(
graph
);
// build kernel
BuildKernel
(
graph
);
// alloc mem
// alloc mem
MemoryAlloc
(
graph
.
get
());
MemoryAlloc
(
graph
.
get
());
// task generate
// task generate
...
@@ -365,6 +380,7 @@ void AscendSession::BuildGraph(GraphId graph_id) {
...
@@ -365,6 +380,7 @@ void AscendSession::BuildGraph(GraphId graph_id) {
void
AscendSession
::
CompileChildGraph
(
const
KernelGraphPtr
&
child_graph
)
{
void
AscendSession
::
CompileChildGraph
(
const
KernelGraphPtr
&
child_graph
)
{
MS_EXCEPTION_IF_NULL
(
child_graph
);
MS_EXCEPTION_IF_NULL
(
child_graph
);
MS_LOG
(
INFO
)
<<
"CompileChildGraph "
<<
child_graph
->
ToString
();
opt
::
AscendBackendIRFusionOptimization
(
child_graph
);
opt
::
AscendBackendIRFusionOptimization
(
child_graph
);
// select kernel build info
// select kernel build info
SelectKernel
(
*
child_graph
);
SelectKernel
(
*
child_graph
);
...
@@ -376,12 +392,14 @@ void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) {
...
@@ -376,12 +392,14 @@ void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) {
auto
runtime_instance
=
device
::
KernelRuntimeManager
::
Instance
().
GetKernelRuntime
(
kAscendDevice
,
device_id_
);
auto
runtime_instance
=
device
::
KernelRuntimeManager
::
Instance
().
GetKernelRuntime
(
kAscendDevice
,
device_id_
);
MS_EXCEPTION_IF_NULL
(
runtime_instance
);
MS_EXCEPTION_IF_NULL
(
runtime_instance
);
runtime_instance
->
AssignStaticMemoryInput
(
child_graph
.
get
());
runtime_instance
->
AssignStaticMemoryInput
(
child_graph
.
get
());
runtime_instance
->
AssignStaticMemoryValueNode
(
child_graph
.
get
());
}
}
void
AscendSession
::
RunGraph
(
const
GraphId
&
graph_id
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
inputs
,
void
AscendSession
::
RunGraph
(
const
GraphId
&
graph_id
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
inputs
,
VectorRef
*
const
outputs
)
{
VectorRef
*
const
outputs
)
{
MS_LOG
(
INFO
)
<<
"start"
;
MS_LOG
(
INFO
)
<<
"start"
;
auto
kernel_graph
=
GetGraph
(
graph_id
);
auto
kernel_graph
=
GetGraph
(
graph_id
);
DumpIR
(
"./run_graph.ir"
,
kernel_graph
);
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
// if none of child graph and no anf output exists
// if none of child graph and no anf output exists
if
(
!
kernel_graph
->
executable
())
{
if
(
!
kernel_graph
->
executable
())
{
...
@@ -1378,10 +1396,10 @@ void AscendSession::SyncInitialTenosrToDevice() {
...
@@ -1378,10 +1396,10 @@ void AscendSession::SyncInitialTenosrToDevice() {
}
}
}
}
KernelGraphPtr
AscendSession
::
ConstructSplitedGraph
(
const
KernelGraphPtr
&
new_kernel_graph
,
std
::
vector
<
AnfNodePtr
>
AscendSession
::
ConstructSplitedGraph
(
const
KernelGraphPtr
&
new_kernel_graph
,
const
std
::
vector
<
CNodePtr
>
&
list
)
{
const
std
::
vector
<
CNodePtr
>
&
list
)
{
MS_EXCEPTION_IF_NULL
(
new_kernel_graph
);
MS_EXCEPTION_IF_NULL
(
new_kernel_graph
);
MS_LOG
(
INFO
)
<<
"start
split
kernel graph:"
<<
new_kernel_graph
->
graph_id
();
MS_LOG
(
INFO
)
<<
"start
contruct splited
kernel graph:"
<<
new_kernel_graph
->
graph_id
();
// count the output of every anf node
// count the output of every anf node
std
::
set
<
AnfNodePtr
>
has_output_nodes
;
std
::
set
<
AnfNodePtr
>
has_output_nodes
;
for
(
auto
&
anf_node
:
list
)
{
for
(
auto
&
anf_node
:
list
)
{
...
@@ -1390,21 +1408,23 @@ KernelGraphPtr AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_ke
...
@@ -1390,21 +1408,23 @@ KernelGraphPtr AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_ke
}
}
}
}
MS_LOG
(
INFO
)
<<
"Construct input of kernel graph:"
<<
new_kernel_graph
->
graph_id
();
MS_LOG
(
INFO
)
<<
"Construct input of kernel graph:"
<<
new_kernel_graph
->
graph_id
();
std
::
vector
<
AnfNodePtr
>
call_node_inputs
;
auto
graph_inputs
=
new_kernel_graph
->
MutableInputs
();
MS_EXCEPTION_IF_NULL
(
graph_inputs
);
// create new parameter from cnode
// create new parameter from cnode
for
(
auto
&
anf_node
:
list
)
{
for
(
auto
&
anf_node
:
list
)
{
auto
cnode
=
anf_node
->
cast
<
CNodePtr
>
();
auto
cnode
=
anf_node
->
cast
<
CNodePtr
>
();
for
(
size_t
input_idx
=
1
;
input_idx
<
cnode
->
inputs
().
size
();
input_idx
++
)
{
for
(
size_t
input_idx
=
1
;
input_idx
<
cnode
->
inputs
().
size
();
input_idx
++
)
{
auto
input
=
cnode
->
inputs
()[
input_idx
];
auto
input
=
cnode
->
inputs
()[
input_idx
];
MS_EXCEPTION_IF_NULL
(
input
);
MS_EXCEPTION_IF_NULL
(
input
);
if
(
!
input
->
isa
<
CNode
>
())
{
if
(
input
->
isa
<
Parameter
>
())
{
graph_inputs
->
push_back
(
input
);
cnode
->
set_input
(
input_idx
,
input
);
cnode
->
set_input
(
input_idx
,
input
);
continue
;
}
else
if
(
AnfAlgo
::
GetGraphId
(
input
.
get
())
!=
new_kernel_graph
->
graph_id
())
{
}
if
(
AnfAlgo
::
GetGraphId
(
input
.
get
())
!=
new_kernel_graph
->
graph_id
())
{
auto
new_parameter
=
CreateNewParameterFromCNode
(
input
,
true
,
new_kernel_graph
.
get
());
auto
new_parameter
=
CreateNewParameterFromCNode
(
input
,
true
,
new_kernel_graph
.
get
());
cnode
->
set_input
(
input_idx
,
new_parameter
);
cnode
->
set_input
(
input_idx
,
new_parameter
);
new_kernel_graph
->
SetRealInput
(
new_parameter
,
input
);
}
}
call_node_inputs
.
push_back
(
input
);
}
}
}
}
MS_LOG
(
INFO
)
<<
"Construct output of kernel graph:"
<<
new_kernel_graph
->
graph_id
();
MS_LOG
(
INFO
)
<<
"Construct output of kernel graph:"
<<
new_kernel_graph
->
graph_id
();
...
@@ -1424,7 +1444,7 @@ KernelGraphPtr AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_ke
...
@@ -1424,7 +1444,7 @@ KernelGraphPtr AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_ke
new_kernel_graph
->
set_output
(
new_kernel_graph
->
NewCNode
(
make_tuple_inputs
));
new_kernel_graph
->
set_output
(
new_kernel_graph
->
NewCNode
(
make_tuple_inputs
));
}
}
MS_LOG
(
INFO
)
<<
"end"
;
MS_LOG
(
INFO
)
<<
"end"
;
return
new_kernel_graph
;
return
call_node_inputs
;
}
}
void
AscendSession
::
SplitGraphs
(
const
KernelGraphPtr
&
root_graph
)
{
void
AscendSession
::
SplitGraphs
(
const
KernelGraphPtr
&
root_graph
)
{
...
@@ -1438,7 +1458,7 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
...
@@ -1438,7 +1458,7 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL
(
graph
);
MS_EXCEPTION_IF_NULL
(
graph
);
auto
apply_list
=
GetCNodes
(
TopoSort
(
graph
->
get_return
()));
auto
apply_list
=
GetCNodes
(
TopoSort
(
graph
->
get_return
()));
// update the root graph child graph order
// update the root graph child graph order
graph
->
UpdateChildGraphOrder
(
);
AscendControlParser
::
UpdateChildGraphOrder
(
NOT_NULL
(
graph
)
);
// get child list from current graph
// get child list from current graph
std
::
vector
<
std
::
vector
<
CNodePtr
>>
child_graph_lists
=
GetChildList
(
*
graph
,
apply_list
);
std
::
vector
<
std
::
vector
<
CNodePtr
>>
child_graph_lists
=
GetChildList
(
*
graph
,
apply_list
);
auto
bind_new_call_to_new_graph
=
[
&
](
std
::
vector
<
CNodePtr
>
child_graph_list
)
->
AnfNodePtr
{
auto
bind_new_call_to_new_graph
=
[
&
](
std
::
vector
<
CNodePtr
>
child_graph_list
)
->
AnfNodePtr
{
...
@@ -1457,7 +1477,8 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
...
@@ -1457,7 +1477,8 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
for
(
auto
&
child_graph_node
:
child_graph_list
)
{
for
(
auto
&
child_graph_node
:
child_graph_list
)
{
AnfAlgo
::
SetGraphId
(
child_graph
->
graph_id
(),
child_graph_node
.
get
());
AnfAlgo
::
SetGraphId
(
child_graph
->
graph_id
(),
child_graph_node
.
get
());
}
}
ConstructSplitedGraph
(
child_graph
,
child_graph_list
);
auto
call_node_args
=
ConstructSplitedGraph
(
child_graph
,
child_graph_list
);
std
::
copy
(
call_node_args
.
begin
(),
call_node_args
.
end
(),
std
::
back_inserter
(
new_call_input
));
auto
new_call
=
graph
->
NewCNode
(
new_call_input
);
auto
new_call
=
graph
->
NewCNode
(
new_call_input
);
AnfAlgo
::
SetNodeAttr
(
"graph id"
,
MakeValue
(
graph
->
graph_id
()),
new_call
);
AnfAlgo
::
SetNodeAttr
(
"graph id"
,
MakeValue
(
graph
->
graph_id
()),
new_call
);
return
new_call
;
return
new_call
;
...
@@ -1466,26 +1487,59 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
...
@@ -1466,26 +1487,59 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
std
::
list
<
AnfNodePtr
>
depend_input
=
{};
std
::
list
<
AnfNodePtr
>
depend_input
=
{};
for
(
size_t
call_index
=
0
;
call_index
<
child_graph_lists
.
size
();
call_index
++
)
{
for
(
size_t
call_index
=
0
;
call_index
<
child_graph_lists
.
size
();
call_index
++
)
{
auto
call_node
=
bind_new_call_to_new_graph
(
child_graph_lists
[
call_index
]);
auto
call_node
=
bind_new_call_to_new_graph
(
child_graph_lists
[
call_index
]);
MS_EXCEPTION_IF_NULL
(
call_node
);
// if call node is the last call of true graph,no need create child graph after that
auto
child_graphs
=
AnfAlgo
::
GetCallNodeKernelGraph
(
call_node
->
cast
<
CNodePtr
>
());
depend_input
.
push_front
(
call_node
);
depend_input
.
push_front
(
call_node
);
if
(
child_graphs
.
size
()
==
1
&&
child_graphs
[
0
]
==
graph
->
parent_graph
())
{
break
;
}
}
}
depend_input
.
push_front
(
graph
->
NewValueNode
(
NewValueNode
(
std
::
make_shared
<
Primitive
>
(
prim
::
kPrimDepend
->
name
()))));
depend_input
.
push_front
(
graph
->
NewValueNode
(
NewValueNode
(
std
::
make_shared
<
Primitive
>
(
prim
::
kPrimDepend
->
name
()))));
auto
depend
=
graph
->
NewCNode
(
std
::
vector
<
AnfNodePtr
>
(
depend_input
.
begin
(),
depend_input
.
end
()));
auto
depend
=
graph
->
NewCNode
(
std
::
vector
<
AnfNodePtr
>
(
depend_input
.
begin
(),
depend_input
.
end
()));
auto
new_return_primitive
=
auto
new_return_primitive
=
graph
->
NewValueNode
(
NewValueNode
(
std
::
make_shared
<
Primitive
>
(
prim
::
kPrimReturn
->
name
())));
graph
->
NewValueNode
(
NewValueNode
(
std
::
make_shared
<
Primitive
>
(
prim
::
kPrimReturn
->
name
())));
graph
->
set_return
(
graph
->
NewCNode
({
new_return_primitive
,
depend
}));
graph
->
set_return
(
graph
->
NewCNode
({
new_return_primitive
,
depend
}));
AnfNodePtr
pre_call_node
=
nullptr
;
AnfNodePtr
cur_call_node
=
nullptr
;
auto
iter
=
depend_input
.
begin
();
for
(
++
iter
;
iter
!=
depend_input
.
end
();
++
iter
)
{
pre_call_node
=
cur_call_node
;
cur_call_node
=
*
iter
;
if
(
pre_call_node
!=
nullptr
&&
cur_call_node
!=
nullptr
)
{
AscendControlParser
::
InsertControlDependToGraph
(
NOT_NULL
(
graph
),
NOT_NULL
(
cur_call_node
),
NOT_NULL
(
pre_call_node
));
}
}
}
}
graph
->
UpdateChildGraphOrder
(
);
AscendControlParser
::
UpdateChildGraphOrder
(
NOT_NULL
(
graph
)
);
UpdateRealInput
(
graph
.
get
());
UpdateRealInput
(
graph
.
get
());
auto
graph_name
=
std
::
string
(
"./kernel-graph-"
).
append
(
std
::
to_string
(
graph
->
graph_id
()));
auto
graph_name
=
std
::
string
(
"./kernel-graph-"
).
append
(
std
::
to_string
(
graph
->
graph_id
()));
DumpIR
(
graph_name
,
graph
);
DumpIR
(
graph_name
,
graph
);
MS_LOG
(
INFO
)
<<
"split graph["
<<
graph
->
graph_id
()
<<
"] end"
;
MS_LOG
(
INFO
)
<<
"split graph["
<<
graph
->
graph_id
()
<<
"] end"
;
// recurse to split child graph
// recurse to split child graph
for
(
auto
&
child_graph
:
graph
->
child_graph_order
())
{
for
(
auto
&
child_graph
:
graph
->
child_graph_order
())
{
if
(
child_graph
!=
graph
->
parent_graph
())
{
SplitGraph
(
child_graph
);
SplitGraph
(
child_graph
);
}
}
}
}
}
void
AscendSession
::
LinkChildGraphs
(
NotNull
<
KernelGraphPtr
>
graph
)
{
AscendControlParser
::
LinkGraph
(
graph
);
}
void
AscendSession
::
LinkChildGraphs
(
NotNull
<
KernelGraphPtr
>
graph
)
{
AscendControlParser
::
LinkGraph
(
graph
);
}
void
AscendSession
::
RootGraphExecutorValidate
(
NotNull
<
KernelGraphPtr
>
graph
)
{
AscendControlParser
::
ExecutorValidate
(
graph
);
}
void
AscendSession
::
RecurseCompileGraph
(
const
KernelGraphPtr
&
graph
)
{
CompileChildGraph
(
graph
);
for
(
auto
child_graph
:
graph
->
child_graph_order
())
{
if
(
child_graph
==
graph
->
parent_graph
())
{
continue
;
}
RecurseCompileGraph
(
child_graph
);
}
}
}
// namespace session
}
// namespace session
}
// namespace mindspore
}
// namespace mindspore
mindspore/ccsrc/session/ascend_session.h
浏览文件 @
5aae0d91
...
@@ -104,10 +104,10 @@ class AscendSession : public SessionBasic {
...
@@ -104,10 +104,10 @@ class AscendSession : public SessionBasic {
void
SelectKernelGraphKernel
(
const
KernelGraph
&
graph
)
{}
void
SelectKernelGraphKernel
(
const
KernelGraph
&
graph
)
{}
void
ConvertPredictModel
(
const
KernelGraphPtr
graph
)
{}
void
ConvertPredictModel
(
const
KernelGraphPtr
graph
)
{}
void
HardwareOptimizeGraphs
(
const
KernelGraphPtr
graph
)
{}
void
HardwareOptimizeGraphs
(
const
KernelGraphPtr
graph
)
{}
void
RootGraphExecutorValidate
(
KernelGraph
*
graph
)
{}
void
RootGraphExecutorValidate
(
NotNull
<
KernelGraphPtr
>
graph
);
void
RecurseUpdateAllChildGraohOrder
(
KernelGraph
*
root_graph
);
std
::
vector
<
AnfNodePtr
>
ConstructSplitedGraph
(
const
KernelGraphPtr
&
new_kernel_graph
,
KernelGraphPtr
ConstructSplitedGraph
(
const
KernelGraphPtr
&
new_kernel_graph
,
const
std
::
vector
<
CNodePtr
>
&
list
);
const
std
::
vector
<
CNodePtr
>
&
list
);
void
ChildGraphCommunicationDecrease
(
std
::
vector
<
std
::
vector
<
AnfNodePtr
>>
*
anf_node_lists
);
void
RecurseCompileGraph
(
const
KernelGraphPtr
&
graph
);
// merge execution order list of child graphs
// merge execution order list of child graphs
void
MergeGraphExecOrder
();
void
MergeGraphExecOrder
();
...
...
mindspore/ccsrc/session/kernel_graph.cc
浏览文件 @
5aae0d91
...
@@ -165,6 +165,21 @@ void KernelGraph::SetExecOrderByDefault() {
...
@@ -165,6 +165,21 @@ void KernelGraph::SetExecOrderByDefault() {
}
}
}
}
CheckLoop
();
CheckLoop
();
// resort start label / end goto
std
::
vector
<
CNodePtr
>
re_order
;
if
(
start_label_
!=
nullptr
)
{
re_order
.
push_back
(
start_label_
);
}
for
(
auto
&
node
:
execution_order_
)
{
if
(
node
==
start_label_
||
node
==
end_goto_
)
{
continue
;
}
re_order
.
push_back
(
node
);
}
if
(
end_goto_
!=
nullptr
)
{
re_order
.
push_back
(
end_goto_
);
}
execution_order_
=
re_order
;
}
}
void
KernelGraph
::
CheckLoop
()
{
void
KernelGraph
::
CheckLoop
()
{
...
@@ -360,7 +375,8 @@ void KernelGraph::FrontBackendlMapAdd(const AnfNodePtr &front_anf, const AnfNode
...
@@ -360,7 +375,8 @@ void KernelGraph::FrontBackendlMapAdd(const AnfNodePtr &front_anf, const AnfNode
void
KernelGraph
::
FrontBackendlMapUpdate
(
const
AnfNodePtr
&
old_backend_anf
,
const
AnfNodePtr
&
new_backend_anf
)
{
void
KernelGraph
::
FrontBackendlMapUpdate
(
const
AnfNodePtr
&
old_backend_anf
,
const
AnfNodePtr
&
new_backend_anf
)
{
MS_EXCEPTION_IF_NULL
(
old_backend_anf
);
MS_EXCEPTION_IF_NULL
(
old_backend_anf
);
MS_EXCEPTION_IF_NULL
(
new_backend_anf
);
MS_EXCEPTION_IF_NULL
(
new_backend_anf
);
if
(
old_backend_anf
.
get
()
==
new_backend_anf
.
get
())
{
if
(
old_backend_anf
==
new_backend_anf
)
{
MS_LOG
(
INFO
)
<<
"old:"
<<
old_backend_anf
->
DebugString
()
<<
",new:"
<<
new_backend_anf
->
DebugString
();
MS_LOG
(
EXCEPTION
)
<<
"old can't be same with new"
;
MS_LOG
(
EXCEPTION
)
<<
"old can't be same with new"
;
}
}
if
(
backend_front_anf_map_
.
find
(
old_backend_anf
)
==
backend_front_anf_map_
.
end
())
{
if
(
backend_front_anf_map_
.
find
(
old_backend_anf
)
==
backend_front_anf_map_
.
end
())
{
...
@@ -569,14 +585,13 @@ void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf
...
@@ -569,14 +585,13 @@ void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf
MS_EXCEPTION_IF_NULL
(
new_anf_node
);
MS_EXCEPTION_IF_NULL
(
new_anf_node
);
MS_EXCEPTION_IF_NULL
(
inputs_
);
MS_EXCEPTION_IF_NULL
(
inputs_
);
auto
it
=
node_output_edges_
.
find
(
old_anf_node
);
auto
it
=
node_output_edges_
.
find
(
old_anf_node
);
if
(
it
==
node_output_edges_
.
end
())
{
if
(
it
!=
node_output_edges_
.
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"Can't find anf node in node_output_edges map"
;
const
auto
&
outputs
=
it
->
second
;
}
auto
&
outputs
=
it
->
second
;
for
(
auto
&
output_node
:
outputs
)
{
for
(
auto
&
output_node
:
outputs
)
{
MS_EXCEPTION_IF_NULL
(
output_node
.
first
);
auto
output_cnode
=
output_node
.
first
->
cast
<
CNodePtr
>
();
auto
output_cnode
=
output_node
.
first
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
output_cnode
);
MS_EXCEPTION_IF_NULL
(
output_cnode
);
auto
&
output_node_inputs
=
output_cnode
->
inputs
();
const
auto
&
output_node_inputs
=
output_cnode
->
inputs
();
for
(
size_t
i
=
1
;
i
<
output_node_inputs
.
size
();
i
++
)
{
for
(
size_t
i
=
1
;
i
<
output_node_inputs
.
size
();
i
++
)
{
if
(
output_node_inputs
[
i
]
==
old_anf_node
)
{
if
(
output_node_inputs
[
i
]
==
old_anf_node
)
{
output_cnode
->
set_input
(
i
,
new_anf_node
);
output_cnode
->
set_input
(
i
,
new_anf_node
);
...
@@ -585,16 +600,37 @@ void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf
...
@@ -585,16 +600,37 @@ void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf
// update graph inputs
// update graph inputs
for
(
size_t
i
=
0
;
i
<
inputs_
->
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
inputs_
->
size
();
i
++
)
{
if
((
*
inputs_
)[
i
]
==
old_anf_node
)
{
if
((
*
inputs_
)[
i
]
==
old_anf_node
)
{
MS_LOG
(
INFO
)
<<
"Replace input of graph:"
<<
graph_id_
<<
", old graph input: "
<<
old_anf_node
->
DebugString
()
<<
",new graph input:"
<<
new_anf_node
->
DebugString
();
(
*
inputs_
)[
i
]
=
new_anf_node
;
(
*
inputs_
)[
i
]
=
new_anf_node
;
break
;
break
;
}
}
}
}
MS_LOG
(
INFO
)
<<
"Inputs of graph id:"
<<
graph_id
();
for
(
size_t
i
=
0
;
i
<
inputs
().
size
();
i
++
)
{
MS_LOG
(
INFO
)
<<
"["
<<
i
<<
"]:"
<<
inputs
()[
i
]
->
DebugString
();
}
}
}
// update front to backend map
// update front to backend map
FrontBackendlMapUpdate
(
old_anf_node
,
new_anf_node
);
FrontBackendlMapUpdate
(
old_anf_node
,
new_anf_node
);
// update output depend relations
// update output depend relations
node_output_edges_
[
new_anf_node
]
=
it
->
second
;
node_output_edges_
[
new_anf_node
]
=
it
->
second
;
(
void
)
node_output_edges_
.
erase
(
old_anf_node
);
(
void
)
node_output_edges_
.
erase
(
old_anf_node
);
}
// update graph inputs in child graph
auto
it_real_inputs
=
real_inputs_
.
find
(
old_anf_node
);
if
(
it_real_inputs
!=
real_inputs_
.
end
())
{
// insert new parameter to map
auto
iter
=
real_inputs_
.
find
(
new_anf_node
);
if
(
iter
!=
real_inputs_
.
end
())
{
MS_LOG
(
WARNING
)
<<
new_anf_node
->
DebugString
()
<<
" already exist in real inputs, will be rewrited."
;
iter
->
second
=
it_real_inputs
->
second
;
}
else
{
real_inputs_
[
new_anf_node
]
=
it_real_inputs
->
second
;
}
// erase old parameter in map
real_inputs_
.
erase
(
old_anf_node
);
}
}
}
void
KernelGraph
::
UpdateExecuteKernelStreamLabel
()
{
void
KernelGraph
::
UpdateExecuteKernelStreamLabel
()
{
...
@@ -603,29 +639,6 @@ void KernelGraph::UpdateExecuteKernelStreamLabel() {
...
@@ -603,29 +639,6 @@ void KernelGraph::UpdateExecuteKernelStreamLabel() {
}
}
}
}
void
KernelGraph
::
UpdateChildGraphOrder
()
{
MS_LOG
(
INFO
)
<<
"graph id:"
<<
graph_id_
;
auto
call_nodes
=
FindNodeByPrimitive
(
std
::
make_shared
<
Primitive
>
(
prim
::
kPrimCall
->
name
()));
for
(
auto
&
old_child_graph
:
child_graph_order_
)
{
old_child_graph
->
set_parent_graph
(
nullptr
);
}
child_graph_order_
.
clear
();
for
(
auto
&
call_node
:
call_nodes
)
{
MS_EXCEPTION_IF_NULL
(
call_node
);
auto
call_child_graphs
=
AnfAlgo
::
GetCallNodeKernelGraph
(
call_node
->
cast
<
CNodePtr
>
());
for
(
const
auto
&
child_graph
:
call_child_graphs
)
{
MS_EXCEPTION_IF_NULL
(
child_graph
);
if
(
child_graph
!=
parent_graph
())
{
child_graph
->
set_parent_graph
(
shared_from_this
()
->
cast
<
std
::
shared_ptr
<
KernelGraph
>>
());
child_graph_order_
.
push_back
(
child_graph
);
}
}
}
for
(
size_t
i
=
0
;
i
<
child_graph_order_
.
size
();
i
++
)
{
MS_LOG
(
INFO
)
<<
"child graph["
<<
i
<<
"][id:"
<<
child_graph_order_
[
i
]
->
graph_id
()
<<
"]"
;
}
}
std
::
vector
<
std
::
shared_ptr
<
KernelGraph
>>
KernelGraph
::
GetLeafGraphOrder
()
{
std
::
vector
<
std
::
shared_ptr
<
KernelGraph
>>
KernelGraph
::
GetLeafGraphOrder
()
{
std
::
vector
<
std
::
shared_ptr
<
KernelGraph
>>
leaf_graph_order
;
std
::
vector
<
std
::
shared_ptr
<
KernelGraph
>>
leaf_graph_order
;
if
(
IsLeafGraph
())
{
if
(
IsLeafGraph
())
{
...
@@ -643,9 +656,8 @@ std::vector<std::shared_ptr<KernelGraph>> KernelGraph::GetLeafGraphOrder() {
...
@@ -643,9 +656,8 @@ std::vector<std::shared_ptr<KernelGraph>> KernelGraph::GetLeafGraphOrder() {
bool
KernelGraph
::
IsLeafGraph
()
const
{
return
child_graph_order_
.
empty
();
}
bool
KernelGraph
::
IsLeafGraph
()
const
{
return
child_graph_order_
.
empty
();
}
std
::
vector
<
CNodePtr
>
KernelGraph
::
FindNodeByPrimitive
(
const
PrimitivePtr
&
primitive
)
const
{
std
::
vector
<
CNodePtr
>
KernelGraph
::
FindNodeByPrimitive
(
const
PrimitivePtr
&
primitive
)
const
{
auto
anf_list
=
TopoSort
(
get_return
());
std
::
vector
<
CNodePtr
>
result
;
std
::
vector
<
CNodePtr
>
result
;
for
(
const
auto
&
anf
:
anf_list
)
{
for
(
const
auto
&
anf
:
execution_order_
)
{
if
(
AnfAlgo
::
CheckPrimitiveType
(
anf
,
primitive
)
&&
AnfAlgo
::
GetGraphId
(
anf
.
get
())
==
graph_id_
)
{
if
(
AnfAlgo
::
CheckPrimitiveType
(
anf
,
primitive
)
&&
AnfAlgo
::
GetGraphId
(
anf
.
get
())
==
graph_id_
)
{
result
.
push_back
(
anf
->
cast
<
CNodePtr
>
());
result
.
push_back
(
anf
->
cast
<
CNodePtr
>
());
}
}
...
@@ -653,14 +665,6 @@ std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primi
...
@@ -653,14 +665,6 @@ std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primi
return
result
;
return
result
;
}
}
std
::
set
<
AnfNodePtr
>
KernelGraph
::
GetRealInput
(
const
AnfNodePtr
&
parameter
)
{
MS_EXCEPTION_IF_NULL
(
parameter
);
if
(
real_inputs_
.
find
(
parameter
)
==
real_inputs_
.
end
())
{
return
{};
}
return
real_inputs_
[
parameter
];
}
void
KernelGraph
::
SetRealInput
(
const
AnfNodePtr
&
parameter
,
const
AnfNodePtr
&
arg
)
{
void
KernelGraph
::
SetRealInput
(
const
AnfNodePtr
&
parameter
,
const
AnfNodePtr
&
arg
)
{
MS_EXCEPTION_IF_NULL
(
parameter
);
MS_EXCEPTION_IF_NULL
(
parameter
);
MS_EXCEPTION_IF_NULL
(
arg
);
MS_EXCEPTION_IF_NULL
(
arg
);
...
@@ -674,39 +678,43 @@ void KernelGraph::SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &ar
...
@@ -674,39 +678,43 @@ void KernelGraph::SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &ar
(
void
)
args
.
insert
(
arg
);
(
void
)
args
.
insert
(
arg
);
}
}
std
::
set
<
AnfNodePtr
>
KernelGraph
::
GetRealInput
(
const
AnfNodePtr
&
parameter
)
{
MS_EXCEPTION_IF_NULL
(
parameter
);
auto
iter
=
real_inputs_
.
find
(
parameter
);
if
(
iter
!=
real_inputs_
.
end
())
{
return
iter
->
second
;
}
MS_LOG
(
EXCEPTION
)
<<
parameter
->
DebugString
()
<<
" not found."
;
}
void
KernelGraph
::
UpdateCallRealInput
()
{
void
KernelGraph
::
UpdateCallRealInput
()
{
MS_LOG
(
INFO
)
<<
"Update graph id: "
<<
graph_id_
;
MS_LOG
(
INFO
)
<<
"Update graph id: "
<<
graph_id_
;
for
(
auto
&
it
:
real_inputs_
)
{
for
(
auto
&
it
:
real_inputs_
)
{
auto
&
parameter
=
it
.
first
;
auto
&
parameter
=
it
.
first
;
MS_EXCEPTION_IF_NULL
(
parameter
);
MS_EXCEPTION_IF_NULL
(
parameter
);
auto
&
real_inputs
=
it
.
second
;
auto
&
real_inputs
=
it
.
second
;
std
::
set
<
AnfNodePtr
>
new_real_inputs
;
std
::
vector
<
AnfNodePtr
>
new_real_inputs
;
std
::
set
<
AnfNodePtr
>
erase_real_inputs
;
std
::
set
<
AnfNodePtr
>
erase_real_inputs
;
for
(
auto
&
real_input
:
real_inputs
)
{
for
(
auto
&
real_input
:
real_inputs
)
{
// if real input is a call node ,find the child graph output act as the new real input
// if real input is a call node ,find the child graph output act as the new real input
auto
item_with_index
=
AnfAlgo
::
VisitKernelWithReturnType
(
real_input
,
0
);
auto
item_with_index
=
AnfAlgo
::
VisitKernelWithReturnType
(
real_input
,
0
);
MS_EXCEPTION_IF_NULL
(
item_with_index
.
first
);
MS_EXCEPTION_IF_NULL
(
item_with_index
.
first
);
if
(
AnfAlgo
::
CheckPrimitiveType
(
item_with_index
.
first
,
prim
::
kPrimCall
))
{
if
(
AnfAlgo
::
CheckPrimitiveType
(
item_with_index
.
first
,
prim
::
kPrimCall
))
{
MS_LOG
(
INFO
)
<<
"paramter: "
<<
parameter
->
DebugString
()
<<
" erase real input:"
<<
item_with_index
.
first
->
DebugString
();
(
void
)
erase_real_inputs
.
insert
(
item_with_index
.
first
);
(
void
)
erase_real_inputs
.
insert
(
item_with_index
.
first
);
auto
call_node_outputs
=
GetCallRealOutputs
(
item_with_index
.
first
);
new_real_inputs
=
GetCallRealOutputs
(
item_with_index
.
first
);
for
(
auto
&
call_node_output
:
call_node_outputs
)
{
MS_EXCEPTION_IF_NULL
(
call_node_output
);
MS_LOG
(
INFO
)
<<
"paramter: "
<<
parameter
->
DebugString
()
<<
" insert real input:"
<<
call_node_output
->
DebugString
();
(
void
)
new_real_inputs
.
insert
(
call_node_output
);
}
continue
;
continue
;
}
}
}
for
(
auto
&
erase_node
:
erase_real_inputs
)
{
for
(
auto
&
erase_node
:
erase_real_inputs
)
{
MS_LOG
(
INFO
)
<<
"paramter: "
<<
parameter
->
DebugString
()
<<
" erase real input:"
<<
erase_node
->
DebugString
();
(
void
)
real_inputs
.
erase
(
erase_node
);
(
void
)
real_inputs
.
erase
(
erase_node
);
}
}
for
(
auto
&
new_real_input
:
new_real_inputs
)
{
for
(
auto
&
new_real_input
:
new_real_inputs
)
{
MS_LOG
(
INFO
)
<<
"paramter: "
<<
parameter
->
DebugString
()
<<
" insert real input:"
<<
new_real_input
->
DebugString
();
(
void
)
real_inputs
.
insert
(
new_real_input
);
(
void
)
real_inputs
.
insert
(
new_real_input
);
}
}
}
}
}
}
}
std
::
string
KernelGraph
::
ToString
()
const
{
return
std
::
string
(
"kernel_graph_"
).
append
(
std
::
to_string
(
graph_id_
));
}
std
::
string
KernelGraph
::
ToString
()
const
{
return
std
::
string
(
"kernel_graph_"
).
append
(
std
::
to_string
(
graph_id_
));
}
...
...
mindspore/ccsrc/session/kernel_graph.h
浏览文件 @
5aae0d91
...
@@ -103,10 +103,9 @@ class KernelGraph : public FuncGraph {
...
@@ -103,10 +103,9 @@ class KernelGraph : public FuncGraph {
void
UpdateExecuteKernelStreamLabel
();
void
UpdateExecuteKernelStreamLabel
();
// calculate the leaf graph order of root graph
// calculate the leaf graph order of root graph
std
::
vector
<
std
::
shared_ptr
<
KernelGraph
>>
GetLeafGraphOrder
();
std
::
vector
<
std
::
shared_ptr
<
KernelGraph
>>
GetLeafGraphOrder
();
// update the child graph order of graph
// the child graph of current graph
void
UpdateChildGraphOrder
();
const
std
::
vector
<
std
::
shared_ptr
<
KernelGraph
>>
&
child_graph_order
()
const
{
return
child_graph_order_
;
}
// get the child graph of current graph
void
set_child_graph_order
(
const
std
::
vector
<
std
::
shared_ptr
<
KernelGraph
>>
&
order
)
{
child_graph_order_
=
order
;
}
std
::
vector
<
std
::
shared_ptr
<
KernelGraph
>>
child_graph_order
()
const
{
return
child_graph_order_
;
}
// checkout whether current graph is leaf graph
// checkout whether current graph is leaf graph
bool
IsLeafGraph
()
const
;
bool
IsLeafGraph
()
const
;
...
@@ -123,6 +122,7 @@ class KernelGraph : public FuncGraph {
...
@@ -123,6 +122,7 @@ class KernelGraph : public FuncGraph {
// find anf node in graph
// find anf node in graph
std
::
vector
<
CNodePtr
>
FindNodeByPrimitive
(
const
PrimitivePtr
&
primitive
)
const
;
std
::
vector
<
CNodePtr
>
FindNodeByPrimitive
(
const
PrimitivePtr
&
primitive
)
const
;
// get real inputs
// get real inputs
const
std
::
map
<
AnfNodePtr
,
std
::
set
<
AnfNodePtr
>>
&
real_inputs
()
const
{
return
real_inputs_
;
}
std
::
set
<
AnfNodePtr
>
GetRealInput
(
const
AnfNodePtr
&
parameter
);
std
::
set
<
AnfNodePtr
>
GetRealInput
(
const
AnfNodePtr
&
parameter
);
void
SetRealInput
(
const
AnfNodePtr
&
parameter
,
const
AnfNodePtr
&
arg
);
void
SetRealInput
(
const
AnfNodePtr
&
parameter
,
const
AnfNodePtr
&
arg
);
// used to dump ir
// used to dump ir
...
@@ -132,6 +132,8 @@ class KernelGraph : public FuncGraph {
...
@@ -132,6 +132,8 @@ class KernelGraph : public FuncGraph {
void
set_start_label
(
const
CNodePtr
&
start_label
)
{
start_label_
=
start_label
;
}
void
set_start_label
(
const
CNodePtr
&
start_label
)
{
start_label_
=
start_label
;
}
CNodePtr
get_start_label
()
{
return
start_label_
;
}
CNodePtr
get_start_label
()
{
return
start_label_
;
}
void
set_end_goto
(
const
CNodePtr
&
end_goto
)
{
end_goto_
=
end_goto
;
}
CNodePtr
get_end_goto
()
{
return
end_goto_
;
}
private:
private:
// remove value node form graph
// remove value node form graph
...
@@ -185,6 +187,7 @@ class KernelGraph : public FuncGraph {
...
@@ -185,6 +187,7 @@ class KernelGraph : public FuncGraph {
std
::
map
<
AnfNodePtr
,
std
::
set
<
AnfNodePtr
>>
real_inputs_
;
std
::
map
<
AnfNodePtr
,
std
::
set
<
AnfNodePtr
>>
real_inputs_
;
CNodePtr
start_label_
;
CNodePtr
start_label_
;
CNodePtr
end_goto_
;
};
};
}
// namespace session
}
// namespace session
using
KernelGraphPtr
=
std
::
shared_ptr
<
session
::
KernelGraph
>
;
using
KernelGraphPtr
=
std
::
shared_ptr
<
session
::
KernelGraph
>
;
...
...
mindspore/ccsrc/session/session_basic.cc
浏览文件 @
5aae0d91
...
@@ -147,6 +147,7 @@ BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
...
@@ -147,6 +147,7 @@ BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
MS_LOG
(
INFO
)
<<
"create tensor for output["
<<
anf
->
DebugString
()
<<
"]"
;
MS_LOG
(
INFO
)
<<
"create tensor for output["
<<
anf
->
DebugString
()
<<
"]"
;
auto
item_with_index
=
AnfAlgo
::
VisitKernelWithReturnType
(
anf
,
0
);
auto
item_with_index
=
AnfAlgo
::
VisitKernelWithReturnType
(
anf
,
0
);
MS_EXCEPTION_IF_NULL
(
item_with_index
.
first
);
MS_EXCEPTION_IF_NULL
(
item_with_index
.
first
);
MS_LOG
(
INFO
)
<<
"create tensor for output after visit:"
<<
item_with_index
.
first
->
DebugString
();
// special handle for maketuple
// special handle for maketuple
if
(
AnfAlgo
::
CheckPrimitiveType
(
item_with_index
.
first
,
prim
::
kPrimMakeTuple
))
{
if
(
AnfAlgo
::
CheckPrimitiveType
(
item_with_index
.
first
,
prim
::
kPrimMakeTuple
))
{
auto
cnode
=
item_with_index
.
first
->
cast
<
CNodePtr
>
();
auto
cnode
=
item_with_index
.
first
->
cast
<
CNodePtr
>
();
...
@@ -479,31 +480,12 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
...
@@ -479,31 +480,12 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
}
}
for
(
size_t
input_idx
=
1
;
input_idx
<
cnode
->
inputs
().
size
();
input_idx
++
)
{
for
(
size_t
input_idx
=
1
;
input_idx
<
cnode
->
inputs
().
size
();
input_idx
++
)
{
auto
anf
=
cnode
->
input
s
()[
input_idx
]
;
auto
anf
=
cnode
->
input
(
input_idx
)
;
MS_EXCEPTION_IF_NULL
(
anf
);
MS_EXCEPTION_IF_NULL
(
anf
);
// anf has been created before
// anf has been created before
if
(
graph
->
GetBackendAnfByFrontAnf
(
anf
)
!=
nullptr
)
{
if
(
graph
->
GetBackendAnfByFrontAnf
(
anf
)
!=
nullptr
)
{
cnode_inputs
.
emplace_back
(
graph
->
GetBackendAnfByFrontAnf
(
anf
));
cnode_inputs
.
emplace_back
(
graph
->
GetBackendAnfByFrontAnf
(
anf
));
continue
;
continue
;
}
else
if
(
anf
->
isa
<
ValueNode
>
())
{
if
(
!
IsValueNode
<
FuncGraph
>
(
anf
))
{
// if input is a common value node,
auto
new_value_node
=
CreateNewValueNode
(
anf
,
graph
);
if
(
new_value_node
!=
nullptr
)
{
cnode_inputs
.
emplace_back
(
new_value_node
);
}
}
else
{
// if input is a ValueNode<FuncGraph>
auto
new_value_node
=
CreateValueNodeKernelGraph
(
anf
,
graph
);
if
(
new_value_node
!=
nullptr
)
{
cnode_inputs
.
emplace_back
(
new_value_node
);
}
}
continue
;
}
else
if
(
anf
->
isa
<
Parameter
>
())
{
auto
new_parameter
=
CreateNewParameter
(
anf
,
graph
);
cnode_inputs
.
push_back
(
new_parameter
);
continue
;
}
}
MS_LOG
(
EXCEPTION
)
<<
"Unexpected input["
<<
anf
->
DebugString
()
<<
"]"
;
MS_LOG
(
EXCEPTION
)
<<
"Unexpected input["
<<
anf
->
DebugString
()
<<
"]"
;
}
}
...
@@ -613,32 +595,22 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
...
@@ -613,32 +595,22 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
for
(
const
auto
&
node
:
node_list
)
{
for
(
const
auto
&
node
:
node_list
)
{
MS_EXCEPTION_IF_NULL
(
node
);
MS_EXCEPTION_IF_NULL
(
node
);
MS_LOG
(
DEBUG
)
<<
"Start create new cnode, node = "
<<
node
->
DebugString
();
MS_LOG
(
DEBUG
)
<<
"Start create new cnode, node = "
<<
node
->
DebugString
();
if
(
!
node
->
isa
<
CNode
>
())
{
if
(
node
->
isa
<
Parameter
>
())
{
MS_LOG
(
DEBUG
)
<<
"Node "
<<
node
->
DebugString
()
<<
" is not CNode"
;
(
void
)
CreateNewParameter
(
node
,
graph
.
get
());
continue
;
}
else
if
(
node
->
isa
<
ValueNode
>
())
{
if
(
!
IsValueNode
<
FuncGraph
>
(
node
))
{
// if input is a common value node,
(
void
)
CreateNewValueNode
(
node
,
graph
.
get
());
}
else
{
// if input is a ValueNode<FuncGraph>
auto
child_graph
=
ConstructKernelGraph
(
AnfAlgo
::
GetValueNodeFuncGraph
(
node
));
auto
new_value_node
=
CreateValueNodeKernelGraph
(
node
,
graph
.
get
());
}
continue
;
continue
;
}
else
{
}
else
{
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
MS_EXCEPTION_IF_NULL
(
cnode
);
// recurse control ops: call, partial
auto
attr_input
=
cnode
->
input
(
kAnfPrimitiveIndex
);
MS_EXCEPTION_IF_NULL
(
attr_input
);
if
(
IsValueNode
<
FuncGraph
>
(
attr_input
))
{
// recurse call subgraph
auto
sub_func_graph
=
AnfAlgo
::
GetValueNodeFuncGraph
(
attr_input
);
ConstructKernelGraph
(
sub_func_graph
);
}
else
if
(
IsValueNode
<
Primitive
>
(
attr_input
))
{
auto
prim
=
GetCNodePrimitive
(
node
);
MS_EXCEPTION_IF_NULL
(
prim
);
if
(
prim
->
name
()
==
kPartialOpName
)
{
// recurse partial subgraph
auto
func_graph_node
=
cnode
->
input
(
kAnfPartialFuncGraphIndex
);
MS_EXCEPTION_IF_NULL
(
func_graph_node
);
auto
sub_func_graph
=
AnfAlgo
::
GetValueNodeFuncGraph
(
func_graph_node
);
ConstructKernelGraph
(
sub_func_graph
);
}
}
// create a new cnode object
// create a new cnode object
auto
new_cnode
=
CreateNewCNode
(
cnode
,
graph
.
get
());
auto
new_cnode
=
CreateNewCNode
(
cnode
,
graph
.
get
());
MS_EXCEPTION_IF_NULL
(
new_cnode
);
MS_EXCEPTION_IF_NULL
(
new_cnode
);
...
@@ -650,7 +622,21 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
...
@@ -650,7 +622,21 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
}
}
}
}
}
}
auto
graph_inputs
=
graph
->
MutableInputs
();
MS_EXCEPTION_IF_NULL
(
graph_inputs
);
graph_inputs
->
clear
();
for
(
auto
&
parameter
:
func_graph
->
parameters
())
{
MS_EXCEPTION_IF_NULL
(
parameter
);
auto
backend_parameter
=
graph
->
GetBackendAnfByFrontAnf
(
parameter
);
if
(
backend_parameter
==
nullptr
)
{
// for example "def f(x,y,z) {return x + y}", parameter z in unused
CreateNewParameterFromParameter
(
parameter
,
false
,
graph
.
get
());
MS_LOG
(
INFO
)
<<
"Can't find parameter:"
<<
parameter
->
DebugString
();
continue
;
}
MS_LOG
(
INFO
)
<<
"graph["
<<
graph
->
graph_id
()
<<
"],parameter:"
<<
parameter
->
DebugString
();
graph_inputs
->
push_back
(
backend_parameter
);
}
MS_EXCEPTION_IF_NULL
(
context_
);
MS_EXCEPTION_IF_NULL
(
context_
);
FuncGraphManagerPtr
manager
=
context_
->
manager
();
FuncGraphManagerPtr
manager
=
context_
->
manager
();
if
(
manager
)
{
if
(
manager
)
{
...
@@ -716,6 +702,11 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap
...
@@ -716,6 +702,11 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
)
const
{
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
)
const
{
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
MS_EXCEPTION_IF_NULL
(
outputs
);
MS_EXCEPTION_IF_NULL
(
outputs
);
if
(
!
kernel_graph
->
child_graph_order
().
empty
())
{
// use the last child graph output as the root graph output
UpdateOutputs
(
kernel_graph
->
child_graph_order
().
back
(),
outputs
,
input_tensors
);
return
;
}
auto
anf_outputs
=
kernel_graph
->
outputs
();
auto
anf_outputs
=
kernel_graph
->
outputs
();
for
(
auto
&
item
:
anf_outputs
)
{
for
(
auto
&
item
:
anf_outputs
)
{
MS_LOG
(
INFO
)
<<
"update output["
<<
item
->
DebugString
()
<<
"]"
;
MS_LOG
(
INFO
)
<<
"update output["
<<
item
->
DebugString
()
<<
"]"
;
...
...
mindspore/ccsrc/vm/transform.cc
浏览文件 @
5aae0d91
...
@@ -487,8 +487,7 @@ void CompileGraph::AddExternal(const LinConvertResult &result) {
...
@@ -487,8 +487,7 @@ void CompileGraph::AddExternal(const LinConvertResult &result) {
}
}
void
TraverseGraphMap
(
void
TraverseGraphMap
(
const
FuncGraphManagerPtr
&
manager_ptr
,
FuncGraphTransaction
*
const
tr
,
const
FuncGraphManagerPtr
&
manager_ptr
,
FuncGraphTransaction
*
const
tr
,
const
FuncGraphSet
&
fgs
,
const
FuncGraphSet
&
fgs
,
const
std
::
function
<
std
::
shared_ptr
<
FuncGraph
>
(
const
PrimitivePtr
,
const
AbstractFunctionPtr
)
>
&
get_prim_graph
)
{
const
std
::
function
<
std
::
shared_ptr
<
FuncGraph
>
(
const
PrimitivePtr
,
const
AbstractFunctionPtr
)
>
&
get_prim_graph
)
{
MS_EXCEPTION_IF_NULL
(
manager_ptr
);
MS_EXCEPTION_IF_NULL
(
manager_ptr
);
MS_EXCEPTION_IF_NULL
(
tr
);
MS_EXCEPTION_IF_NULL
(
tr
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录