Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
9fe6074c
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9fe6074c
编写于
5月 23, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
5月 23, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1384 [control sink refactor]Update real input if it is a call
Merge pull request !1384 from chenfei_mindspore/sort-call-node
上级
b8e25c38
5e9edc16
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
116 addition
and
32 deletion
+116
-32
mindspore/ccsrc/session/anf_runtime_algorithm.cc
mindspore/ccsrc/session/anf_runtime_algorithm.cc
+12
-2
mindspore/ccsrc/session/anf_runtime_algorithm.h
mindspore/ccsrc/session/anf_runtime_algorithm.h
+1
-0
mindspore/ccsrc/session/ascend_session.cc
mindspore/ccsrc/session/ascend_session.cc
+32
-15
mindspore/ccsrc/session/ascend_session.h
mindspore/ccsrc/session/ascend_session.h
+3
-2
mindspore/ccsrc/session/kernel_graph.cc
mindspore/ccsrc/session/kernel_graph.cc
+63
-3
mindspore/ccsrc/session/kernel_graph.h
mindspore/ccsrc/session/kernel_graph.h
+2
-0
mindspore/ccsrc/session/session_basic.cc
mindspore/ccsrc/session/session_basic.cc
+3
-10
未找到文件。
mindspore/ccsrc/session/anf_runtime_algorithm.cc
浏览文件 @
9fe6074c
...
...
@@ -942,7 +942,6 @@ std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CN
}
else
if
(
input1
->
isa
<
CNode
>
()
&&
AnfAlgo
::
CheckPrimitiveType
(
input1
,
prim
::
kPrimSwitch
))
{
auto
switch_node
=
input1
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
switch_node
);
MS_LOG
(
INFO
)
<<
"switch : "
<<
switch_node
->
DebugString
();
auto
get_switch_kernel_graph
=
[
&
](
size_t
input_index
)
->
KernelGraphPtr
{
auto
partial
=
switch_node
->
input
(
input_index
);
MS_EXCEPTION_IF_NULL
(
partial
);
...
...
@@ -950,7 +949,6 @@ std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CN
MS_EXCEPTION_IF_NULL
(
partial_cnode
);
auto
graph_node
=
partial_cnode
->
input
(
1
);
MS_EXCEPTION_IF_NULL
(
graph_node
);
MS_LOG
(
INFO
)
<<
graph_node
->
DebugString
();
auto
graph_value_node
=
graph_node
->
cast
<
ValueNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
graph_value_node
);
auto
graph_value
=
graph_value_node
->
value
();
...
...
@@ -976,5 +974,17 @@ bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) {
}
MS_LOG
(
EXCEPTION
)
<<
"Unexpected input1 of call node,input1:"
<<
input1
->
DebugString
();
}
bool
AnfRuntimeAlgorithm
::
IsWhileTrueGraph
(
const
KernelGraphPtr
&
child_graph
)
{
auto
call_nodes
=
child_graph
->
FindNodeByPrimitive
(
prim
::
kPrimCall
);
for
(
const
auto
&
call_node
:
call_nodes
)
{
auto
graphs
=
GetCallNodeKernelGraph
(
call_node
);
if
(
graphs
.
size
()
==
1
&&
graphs
[
0
]
==
child_graph
->
parent_graph
())
{
return
true
;
}
}
return
false
;
}
}
// namespace session
}
// namespace mindspore
mindspore/ccsrc/session/anf_runtime_algorithm.h
浏览文件 @
9fe6074c
...
...
@@ -185,6 +185,7 @@ class AnfRuntimeAlgorithm {
static
FuncGraphPtr
GetValueNodeFuncGraph
(
const
AnfNodePtr
&
node
);
static
std
::
vector
<
KernelGraphPtr
>
GetCallNodeKernelGraph
(
const
CNodePtr
&
call_node
);
static
bool
IsSwitchCall
(
const
CNodePtr
&
call_node
);
static
bool
IsWhileTrueGraph
(
const
KernelGraphPtr
&
child_graph
);
};
}
// namespace session
using
AnfAlgo
=
session
::
AnfRuntimeAlgorithm
;
...
...
mindspore/ccsrc/session/ascend_session.cc
浏览文件 @
9fe6074c
...
...
@@ -18,6 +18,7 @@
#include <map>
#include <tuple>
#include <set>
#include <list>
#include "operator/ops.h"
#include "ir/meta_tensor.h"
#include "ir/anf.h"
...
...
@@ -160,7 +161,7 @@ void ClearRunOpMemoryResource(const KernelGraphPtr &kernel_graph) {
std
::
vector
<
CNodePtr
>
GetCNodes
(
const
std
::
vector
<
AnfNodePtr
>
&
anf_nodes
)
{
std
::
vector
<
CNodePtr
>
cnodes
=
{};
size_t
i
=
0
;
for
(
auto
anf
:
anf_nodes
)
{
for
(
const
auto
&
anf
:
anf_nodes
)
{
MS_LOG
(
INFO
)
<<
"apply_list["
<<
i
++
<<
"] = "
<<
anf
->
DebugString
();
MS_EXCEPTION_IF_NULL
(
anf
);
if
(
anf
->
isa
<
CNode
>
())
{
...
...
@@ -192,6 +193,8 @@ std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, co
return
ret
;
}
// if a call has kernel input, it's a child graph split from ME, so these kernel input should be set into real input of
// graph.For example, call input = (prim,graph,kernel1,kernel2),then real_input = [kernel1,kernel2]
void
UpdateRealInput
(
KernelGraph
*
graph
)
{
auto
call_nodes
=
graph
->
FindNodeByPrimitive
(
prim
::
kPrimCall
);
auto
bind_call_partial_with_parameter
=
[
&
](
const
std
::
vector
<
AnfNodePtr
>
&
parameters
,
...
...
@@ -239,6 +242,15 @@ void UpdateRealInput(KernelGraph *graph) {
}
}
}
void
RecurseToUpdateCallRealInput
(
KernelGraph
*
graph
)
{
MS_EXCEPTION_IF_NULL
(
graph
);
MS_LOG
(
INFO
)
<<
"start graph id:"
<<
graph
->
graph_id
();
graph
->
UpdateCallRealInput
();
for
(
auto
&
child_graph
:
graph
->
child_graph_order
())
{
RecurseToUpdateCallRealInput
(
child_graph
.
get
());
}
}
}
// namespace
GraphId
AscendSession
::
CompileGraph
(
const
AnfNodePtrList
&
lst
,
const
AnfNodePtrList
&
outputs
)
{
...
...
@@ -254,7 +266,7 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
MS_LOG
(
INFO
)
<<
"start"
;
auto
graph
=
ConstructKernelGraph
(
func_graph
);
// split switch
SplitGraph
(
graph
);
SplitGraph
s
(
graph
);
// insert goto labels and label_sets
LinkChildGraphs
(
NOT_NULL
(
graph
));
// resource initialize
...
...
@@ -1366,7 +1378,7 @@ void AscendSession::SyncInitialTenosrToDevice() {
}
}
KernelGraphPtr
AscendSession
::
SplitKernel
Graph
(
const
KernelGraphPtr
&
new_kernel_graph
,
KernelGraphPtr
AscendSession
::
ConstructSplited
Graph
(
const
KernelGraphPtr
&
new_kernel_graph
,
const
std
::
vector
<
CNodePtr
>
&
list
)
{
MS_EXCEPTION_IF_NULL
(
new_kernel_graph
);
MS_LOG
(
INFO
)
<<
"start split kernel graph:"
<<
new_kernel_graph
->
graph_id
();
...
...
@@ -1376,9 +1388,6 @@ KernelGraphPtr AscendSession::SplitKernelGraph(const KernelGraphPtr &new_kernel_
for
(
auto
&
input
:
anf_node
->
inputs
())
{
(
void
)
has_output_nodes
.
insert
(
input
);
}
if
(
AnfAlgo
::
CheckPrimitiveType
(
anf_node
,
prim
::
kPrimReturn
))
{
new_kernel_graph
->
set_return
(
anf_node
->
cast
<
CNodePtr
>
());
}
}
MS_LOG
(
INFO
)
<<
"Construct input of kernel graph:"
<<
new_kernel_graph
->
graph_id
();
// create new parameter from cnode
...
...
@@ -1386,6 +1395,7 @@ KernelGraphPtr AscendSession::SplitKernelGraph(const KernelGraphPtr &new_kernel_
auto
cnode
=
anf_node
->
cast
<
CNodePtr
>
();
for
(
size_t
input_idx
=
1
;
input_idx
<
cnode
->
inputs
().
size
();
input_idx
++
)
{
auto
input
=
cnode
->
inputs
()[
input_idx
];
MS_EXCEPTION_IF_NULL
(
input
);
if
(
!
input
->
isa
<
CNode
>
())
{
cnode
->
set_input
(
input_idx
,
input
);
continue
;
...
...
@@ -1417,6 +1427,12 @@ KernelGraphPtr AscendSession::SplitKernelGraph(const KernelGraphPtr &new_kernel_
return
new_kernel_graph
;
}
void
AscendSession
::
SplitGraphs
(
const
KernelGraphPtr
&
root_graph
)
{
SplitGraph
(
root_graph
);
// replace the real input if the real input is a call
RecurseToUpdateCallRealInput
(
root_graph
.
get
());
}
void
AscendSession
::
SplitGraph
(
const
KernelGraphPtr
&
graph
)
{
MS_LOG
(
INFO
)
<<
"start,graph_id:"
<<
graph
->
graph_id
();
MS_EXCEPTION_IF_NULL
(
graph
);
...
...
@@ -1426,6 +1442,7 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
// get child list from current graph
std
::
vector
<
std
::
vector
<
CNodePtr
>>
child_graph_lists
=
GetChildList
(
*
graph
,
apply_list
);
auto
bind_new_call_to_new_graph
=
[
&
](
std
::
vector
<
CNodePtr
>
child_graph_list
)
->
AnfNodePtr
{
// if child graph list only has a call ,then return the exist call
if
(
child_graph_list
.
size
()
==
1
&&
AnfAlgo
::
CheckPrimitiveType
(
child_graph_list
[
0
],
prim
::
kPrimCall
))
{
return
child_graph_list
[
0
];
}
...
...
@@ -1440,22 +1457,22 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
for
(
auto
&
child_graph_node
:
child_graph_list
)
{
AnfAlgo
::
SetGraphId
(
child_graph
->
graph_id
(),
child_graph_node
.
get
());
}
SplitKernel
Graph
(
child_graph
,
child_graph_list
);
ConstructSplited
Graph
(
child_graph
,
child_graph_list
);
auto
new_call
=
graph
->
NewCNode
(
new_call_input
);
AnfAlgo
::
SetNodeAttr
(
"graph id"
,
MakeValue
(
graph
->
graph_id
()),
new_call
);
return
new_call
;
};
if
(
child_graph_lists
.
size
()
>
1
)
{
std
::
list
<
AnfNodePtr
>
depend_input
=
{};
for
(
size_t
call_index
=
0
;
call_index
<
child_graph_lists
.
size
();
call_index
++
)
{
auto
call_node
=
bind_new_call_to_new_graph
(
child_graph_lists
[
call_index
]);
if
(
call_index
==
0
)
{
depend_input
.
push_front
(
call_node
);
}
depend_input
.
push_front
(
graph
->
NewValueNode
(
NewValueNode
(
std
::
make_shared
<
Primitive
>
(
prim
::
kPrimDepend
->
name
()))));
auto
depend
=
graph
->
NewCNode
(
std
::
vector
<
AnfNodePtr
>
(
depend_input
.
begin
(),
depend_input
.
end
()));
auto
new_return_primitive
=
graph
->
NewValueNode
(
NewValueNode
(
std
::
make_shared
<
Primitive
>
(
prim
::
kPrimReturn
->
name
())));
graph
->
set_return
(
graph
->
NewCNode
({
new_return_primitive
,
call_node
}));
continue
;
}
InsertDependToGraph
(
graph
->
graph_id
(),
call_node
);
}
graph
->
set_return
(
graph
->
NewCNode
({
new_return_primitive
,
depend
}));
}
graph
->
UpdateChildGraphOrder
();
UpdateRealInput
(
graph
.
get
());
...
...
mindspore/ccsrc/session/ascend_session.h
浏览文件 @
9fe6074c
...
...
@@ -97,15 +97,16 @@ class AscendSession : public SessionBasic {
void
SetFinalGraphOutput
(
const
VectorRef
&
vec_output
);
void
SplitGraph
(
const
KernelGraphPtr
&
graph
);
// split graphs with recurse from root graph
void
SplitGraphs
(
const
KernelGraphPtr
&
root_graph
);
void
LinkChildGraphs
(
NotNull
<
KernelGraphPtr
>
graph
);
void
IRFusion
(
const
KernelGraphPtr
&
graph
)
{}
void
SelectKernelGraphKernel
(
const
KernelGraph
&
graph
)
{}
void
ConvertPredictModel
(
const
KernelGraphPtr
graph
)
{}
void
HardwareOptimizeGraphs
(
const
KernelGraphPtr
graph
)
{}
void
RootGraphExecutorValidate
(
KernelGraph
*
graph
)
{}
void
RecurseUpdateAllChildGraohOrder
(
KernelGraph
*
root_graph
);
KernelGraphPtr
SplitKernel
Graph
(
const
KernelGraphPtr
&
new_kernel_graph
,
const
std
::
vector
<
CNodePtr
>
&
list
);
KernelGraphPtr
ConstructSplited
Graph
(
const
KernelGraphPtr
&
new_kernel_graph
,
const
std
::
vector
<
CNodePtr
>
&
list
);
void
ChildGraphCommunicationDecrease
(
std
::
vector
<
std
::
vector
<
AnfNodePtr
>>
*
anf_node_lists
);
// merge execution order list of child graphs
...
...
mindspore/ccsrc/session/kernel_graph.cc
浏览文件 @
9fe6074c
...
...
@@ -39,16 +39,35 @@ void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
MS_LOG
(
DEBUG
)
<<
"Push que:"
<<
node
->
DebugString
();
}
}
std
::
vector
<
AnfNodePtr
>
GetCallRealOutputs
(
const
AnfNodePtr
&
call_node
)
{
auto
item_with_index
=
AnfAlgo
::
VisitKernelWithReturnType
(
call_node
,
0
);
MS_EXCEPTION_IF_NULL
(
item_with_index
.
first
);
if
(
!
AnfAlgo
::
CheckPrimitiveType
(
item_with_index
.
first
,
prim
::
kPrimCall
))
{
return
{
item_with_index
.
first
};
}
std
::
vector
<
AnfNodePtr
>
real_inputs
;
auto
child_graphs
=
AnfAlgo
::
GetCallNodeKernelGraph
(
item_with_index
.
first
->
cast
<
CNodePtr
>
());
for
(
const
auto
&
child_graph
:
child_graphs
)
{
if
(
AnfAlgo
::
IsWhileTrueGraph
(
child_graph
))
{
continue
;
}
auto
real_input
=
child_graph
->
output
();
auto
child_real_inputs
=
GetCallRealOutputs
(
real_input
);
std
::
copy
(
child_real_inputs
.
begin
(),
child_real_inputs
.
end
(),
std
::
back_inserter
(
real_inputs
));
}
return
real_inputs
;
}
}
// namespace
std
::
vector
<
AnfNodePtr
>
KernelGraph
::
outputs
()
const
{
MS_EXCEPTION_IF_NULL
(
output
()
);
if
(
IsPrimitiveCNode
(
output
()
,
prim
::
kPrimMakeTuple
))
{
auto
graph_output
=
output
(
);
if
(
IsPrimitiveCNode
(
graph_output
,
prim
::
kPrimMakeTuple
))
{
auto
make_tuple
=
output
()
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
make_tuple
);
auto
&
inputs
=
make_tuple
->
inputs
();
return
std
::
vector
<
AnfNodePtr
>
(
inputs
.
begin
()
+
1
,
inputs
.
end
());
}
return
std
::
vector
<
AnfNodePtr
>
();
return
std
::
vector
<
AnfNodePtr
>
(
1
,
graph_output
);
}
void
KernelGraph
::
VisitNodeDescendants
(
const
AnfNodePtr
&
node
,
std
::
queue
<
AnfNodePtr
>
*
visit_queue
,
...
...
@@ -587,6 +606,9 @@ void KernelGraph::UpdateExecuteKernelStreamLabel() {
void
KernelGraph
::
UpdateChildGraphOrder
()
{
MS_LOG
(
INFO
)
<<
"graph id:"
<<
graph_id_
;
auto
call_nodes
=
FindNodeByPrimitive
(
std
::
make_shared
<
Primitive
>
(
prim
::
kPrimCall
->
name
()));
for
(
auto
&
old_child_graph
:
child_graph_order_
)
{
old_child_graph
->
set_parent_graph
(
nullptr
);
}
child_graph_order_
.
clear
();
for
(
auto
&
call_node
:
call_nodes
)
{
MS_EXCEPTION_IF_NULL
(
call_node
);
...
...
@@ -640,6 +662,9 @@ std::set<AnfNodePtr> KernelGraph::GetRealInput(const AnfNodePtr ¶meter) {
}
void
KernelGraph
::
SetRealInput
(
const
AnfNodePtr
&
parameter
,
const
AnfNodePtr
&
arg
)
{
MS_EXCEPTION_IF_NULL
(
parameter
);
MS_EXCEPTION_IF_NULL
(
arg
);
MS_LOG
(
INFO
)
<<
"parameter: "
<<
parameter
->
DebugString
()
<<
", real input : "
<<
arg
->
DebugString
();
MS_EXCEPTION_IF_NULL
(
parameter
);
MS_EXCEPTION_IF_NULL
(
arg
);
if
(
real_inputs_
.
find
(
parameter
)
==
real_inputs_
.
end
())
{
...
...
@@ -649,6 +674,41 @@ void KernelGraph::SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &ar
(
void
)
args
.
insert
(
arg
);
}
void
KernelGraph
::
UpdateCallRealInput
()
{
MS_LOG
(
INFO
)
<<
"Update graph id: "
<<
graph_id_
;
for
(
auto
&
it
:
real_inputs_
)
{
auto
&
parameter
=
it
.
first
;
MS_EXCEPTION_IF_NULL
(
parameter
);
auto
&
real_inputs
=
it
.
second
;
std
::
set
<
AnfNodePtr
>
new_real_inputs
;
std
::
set
<
AnfNodePtr
>
erase_real_inputs
;
for
(
auto
&
real_input
:
real_inputs
)
{
// if real input is a call node ,find the child graph output act as the new real input
auto
item_with_index
=
AnfAlgo
::
VisitKernelWithReturnType
(
real_input
,
0
);
MS_EXCEPTION_IF_NULL
(
item_with_index
.
first
);
if
(
AnfAlgo
::
CheckPrimitiveType
(
item_with_index
.
first
,
prim
::
kPrimCall
))
{
MS_LOG
(
INFO
)
<<
"paramter: "
<<
parameter
->
DebugString
()
<<
" erase real input:"
<<
item_with_index
.
first
->
DebugString
();
(
void
)
erase_real_inputs
.
insert
(
item_with_index
.
first
);
auto
call_node_outputs
=
GetCallRealOutputs
(
item_with_index
.
first
);
for
(
auto
&
call_node_output
:
call_node_outputs
)
{
MS_EXCEPTION_IF_NULL
(
call_node_output
);
MS_LOG
(
INFO
)
<<
"paramter: "
<<
parameter
->
DebugString
()
<<
" insert real input:"
<<
call_node_output
->
DebugString
();
(
void
)
new_real_inputs
.
insert
(
call_node_output
);
}
continue
;
}
for
(
auto
&
erase_node
:
erase_real_inputs
)
{
(
void
)
real_inputs
.
erase
(
erase_node
);
}
for
(
auto
&
new_real_input
:
new_real_inputs
)
{
(
void
)
real_inputs
.
insert
(
new_real_input
);
}
}
}
}
std
::
string
KernelGraph
::
ToString
()
const
{
return
std
::
string
(
"kernel_graph_"
).
append
(
std
::
to_string
(
graph_id_
));
}
}
// namespace session
}
// namespace mindspore
mindspore/ccsrc/session/kernel_graph.h
浏览文件 @
9fe6074c
...
...
@@ -127,6 +127,8 @@ class KernelGraph : public FuncGraph {
void
SetRealInput
(
const
AnfNodePtr
&
parameter
,
const
AnfNodePtr
&
arg
);
// used to dump ir
std
::
string
ToString
()
const
override
;
// update the real input if the node is a call
void
UpdateCallRealInput
();
void
set_start_label
(
const
CNodePtr
&
start_label
)
{
start_label_
=
start_label
;
}
CNodePtr
get_start_label
()
{
return
start_label_
;
}
...
...
mindspore/ccsrc/session/session_basic.cc
浏览文件 @
9fe6074c
...
...
@@ -640,16 +640,6 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
MS_EXCEPTION_IF_NULL
(
func_graph_node
);
auto
sub_func_graph
=
AnfAlgo
::
GetValueNodeFuncGraph
(
func_graph_node
);
ConstructKernelGraph
(
sub_func_graph
);
}
else
if
(
prim
->
name
()
==
kReturnOpName
)
{
std
::
vector
<
AnfNodePtr
>
outputs
;
auto
inputs
=
cnode
->
inputs
();
if
(
inputs
.
size
()
<
2
)
{
MS_LOG
(
EXCEPTION
)
<<
"CNode[return] must have two inputs at least, actual inputs size is "
<<
inputs
.
size
();
}
(
void
)
std
::
copy
(
inputs
.
begin
()
+
1
,
inputs
.
end
(),
std
::
back_inserter
(
outputs
));
// add a make_tuple before return as graph output
graph
->
set_output
(
ConstructOutput
(
outputs
,
graph
));
continue
;
}
}
...
...
@@ -659,6 +649,9 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
new_cnode
->
set_abstract
(
cnode
->
abstract
());
new_cnode
->
set_scope
(
cnode
->
scope
());
graph
->
FrontBackendlMapAdd
(
node
,
new_cnode
);
if
(
AnfAlgo
::
CheckPrimitiveType
(
new_cnode
,
prim
::
kPrimReturn
))
{
graph
->
set_return
(
new_cnode
);
}
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录