Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
dcb90588
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
dcb90588
编写于
6月 18, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 18, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2248 bind summary nodes to KernelGraph in order to memory reuse
Merge pull request !2248 from Margaret_wangrui/r0.3
上级
476671b1
6f5303f0
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
104 addition
and
42 deletion
+104
-42
mindspore/ccsrc/session/ascend_session.cc
mindspore/ccsrc/session/ascend_session.cc
+36
-0
mindspore/ccsrc/session/ascend_session.h
mindspore/ccsrc/session/ascend_session.h
+2
-0
mindspore/ccsrc/session/kernel_graph.h
mindspore/ccsrc/session/kernel_graph.h
+10
-0
mindspore/ccsrc/session/session_basic.cc
mindspore/ccsrc/session/session_basic.cc
+55
-42
mindspore/ccsrc/session/session_basic.h
mindspore/ccsrc/session/session_basic.h
+1
-0
未找到文件。
mindspore/ccsrc/session/ascend_session.cc
浏览文件 @
dcb90588
...
...
@@ -321,6 +321,18 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
return
graph_id
;
}
void
AscendSession
::
SetFinalGraphSummaryFlag
(
const
std
::
shared_ptr
<
KernelGraph
>
&
kernel_graph
)
{
auto
graph_order
=
GetGraphOrder
(
kernel_graph
->
graph_id
());
for
(
auto
graph_id
:
graph_order
)
{
auto
child_graph
=
GetGraph
(
graph_id
);
if
(
child_graph
->
summary_node_exist
())
{
kernel_graph
->
set_summary_node_exist
(
true
);
return
;
}
}
kernel_graph
->
set_summary_node_exist
(
false
);
}
void
AscendSession
::
BuildGraph
(
GraphId
graph_id
)
{
MS_LOG
(
INFO
)
<<
"start"
;
auto
graph
=
GetGraph
(
graph_id
);
...
...
@@ -336,6 +348,7 @@ void AscendSession::BuildGraph(GraphId graph_id) {
InsertAllAssigns
();
// insert switch and active to child graph
MergeSwitchCompile
();
SetFinalGraphSummaryFlag
(
graph
);
// OptChildGraphs
auto
graph_order
=
GetGraphOrder
(
final_graph_id_
);
auto
&
graph_type
=
GetGraphOrderType
(
final_graph_id_
);
...
...
@@ -347,6 +360,7 @@ void AscendSession::BuildGraph(GraphId graph_id) {
auto
child_graph
=
GetGraph
(
graph_order
[
i
]);
CompileChildGraph
(
child_graph
);
}
GetSummaryNodes
(
graph
.
get
());
// merge child graph
MergeGraphExecOrder
();
}
else
{
...
...
@@ -768,6 +782,28 @@ GraphId AscendSession::SetFinalGraphInput(const std::vector<AnfNodePtr> &args) {
return
final_graph_id_
;
}
void
AscendSession
::
GetSummaryNodes
(
KernelGraph
*
graph
)
{
MS_LOG
(
DEBUG
)
<<
"Update summary Start"
;
MS_EXCEPTION_IF_NULL
(
graph
);
// if final graph have no child graph
auto
graph_order_iter
=
graph_execute_orders_
.
find
(
graph
->
graph_id
());
if
(
graph_order_iter
==
graph_execute_orders_
.
end
())
{
SessionBasic
::
GetSummaryNodes
(
graph
);
return
;
}
// for every child graph, find summary nodes
auto
summary
=
graph
->
summary_nodes
();
auto
graph_order
=
GetGraphOrder
(
graph
->
graph_id
());
for
(
size_t
i
=
0
;
i
<
graph_order
.
size
();
i
++
)
{
auto
child_graph
=
GetGraph
(
graph_order
[
i
]);
SessionBasic
::
GetSummaryNodes
(
child_graph
.
get
());
auto
child_graph_summary
=
child_graph
->
summary_nodes
();
summary
.
insert
(
child_graph_summary
.
begin
(),
child_graph_summary
.
end
());
}
graph
->
set_summary_nodes
(
summary
);
MS_LOG
(
DEBUG
)
<<
"Update summary end size: "
<<
summary
.
size
();
}
AnfNodePtr
AscendSession
::
CreateFakeOutput
(
GraphId
fake_graph_id
,
const
AnfNodePtr
&
true_output
)
{
auto
fake_graph
=
GetGraph
(
fake_graph_id
);
auto
output_item_with_index
=
AnfAlgo
::
VisitKernelWithReturnType
(
true_output
,
0
);
...
...
mindspore/ccsrc/session/ascend_session.h
浏览文件 @
dcb90588
...
...
@@ -67,6 +67,7 @@ class AscendSession : public SessionBasic {
void
SetActive
(
GraphId
,
GraphId
)
override
;
// compile child graph when session have multiple child graphs
void
CompileChildGraph
(
const
KernelGraphPtr
&
child_graph
);
void
GetSummaryNodes
(
KernelGraph
*
graph
);
private:
void
InitRuntimeResource
();
...
...
@@ -149,6 +150,7 @@ class AscendSession : public SessionBasic {
AnfNodePtr
CreateFakeOutput
(
GraphId
final_graph_id
,
const
AnfNodePtr
&
true_output
);
// sync intial tensors' data to device
void
SyncInitialTenosrToDevice
();
void
SetFinalGraphSummaryFlag
(
const
std
::
shared_ptr
<
KernelGraph
>
&
kernel_graph
);
// member variables
// key is final_graph_id,value is child graph execute order of final graph
...
...
mindspore/ccsrc/session/kernel_graph.h
浏览文件 @
dcb90588
...
...
@@ -40,6 +40,7 @@ class KernelGraph : public FuncGraph {
inputs_
=
std
::
make_shared
<
std
::
vector
<
AnfNodePtr
>>
();
execution_order_
=
{};
executable_
=
true
;
summary_node_exist_
=
false
;
stream_distinction_label_
=
kInvalidDistincLabel
;
}
~
KernelGraph
()
override
;
...
...
@@ -91,6 +92,10 @@ class KernelGraph : public FuncGraph {
bool
executable
()
const
{
return
executable_
;
}
// set executable of graph
void
set_executable
(
bool
executable
)
{
executable_
=
executable
;
}
// set summary_node of graph
void
set_summary_node_exist
(
bool
summary_node_exist
)
{
summary_node_exist_
=
summary_node_exist
;
}
// check whether exist summary node in graph
bool
summary_node_exist
()
const
{
return
summary_node_exist_
;
}
// set invalid inputs for control sink
std
::
vector
<
bool
>
*
MutableValidInputs
()
{
return
&
valid_inputs_
;
}
std
::
vector
<
bool
>
valid_inputs
()
const
{
return
valid_inputs_
;
}
...
...
@@ -133,6 +138,8 @@ class KernelGraph : public FuncGraph {
void
set_start_label
(
const
CNodePtr
&
start_label
)
{
start_label_
=
start_label
;
}
CNodePtr
get_start_label
()
{
return
start_label_
;
}
const
std
::
map
<
std
::
string
,
std
::
pair
<
AnfNodePtr
,
int
>>
&
summary_nodes
()
const
{
return
summary_nodes_
;
}
void
set_summary_nodes
(
const
std
::
map
<
std
::
string
,
std
::
pair
<
AnfNodePtr
,
int
>>
&
nodes
)
{
summary_nodes_
=
nodes
;
}
private:
// remove value node form graph
...
...
@@ -166,6 +173,9 @@ class KernelGraph : public FuncGraph {
// record map between ref final output anf with index and ref origin input with index
std
::
map
<
AnfWithOutIndex
,
AnfWithOutIndex
>
ref_out_in_map_
;
std
::
unordered_map
<
AnfNodePtr
,
std
::
vector
<
std
::
pair
<
AnfNodePtr
,
size_t
>>>
node_output_edges_
;
std
::
map
<
std
::
string
,
std
::
pair
<
AnfNodePtr
,
int
>>
summary_nodes_
;
// exist summary node in graph
bool
summary_node_exist_
;
// graph needn't execute
bool
executable_
;
// valid inputs
...
...
mindspore/ccsrc/session/session_basic.cc
浏览文件 @
dcb90588
...
...
@@ -56,46 +56,6 @@ PyObject *GetParamDefaultInputTensor(const AnfNodePtr &node) {
return
py_param
.
ptr
();
}
void
GetSummaryNodes
(
const
KernelGraph
*
graph
,
std
::
unordered_map
<
std
::
string
,
std
::
pair
<
AnfNodePtr
,
int
>>
*
summary
)
{
MS_LOG
(
DEBUG
)
<<
"Update summary Start"
;
MS_EXCEPTION_IF_NULL
(
graph
);
MS_EXCEPTION_IF_NULL
(
summary
);
summary
->
clear
();
auto
apply_list
=
TopoSort
(
graph
->
get_return
());
for
(
auto
&
n
:
apply_list
)
{
MS_EXCEPTION_IF_NULL
(
n
);
if
(
IsPrimitiveCNode
(
n
,
prim
::
kPrimScalarSummary
)
||
IsPrimitiveCNode
(
n
,
prim
::
kPrimTensorSummary
)
||
IsPrimitiveCNode
(
n
,
prim
::
kPrimImageSummary
)
||
IsPrimitiveCNode
(
n
,
prim
::
kPrimHistogramSummary
))
{
auto
cnode
=
n
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
if
(
cnode
->
inputs
().
size
()
<=
kSummaryGetItem
)
{
MS_LOG
(
EXCEPTION
)
<<
"the node Summary should have 2 inputs at least!"
;
}
auto
node
=
cnode
->
input
(
kSummaryGetItem
);
MS_EXCEPTION_IF_NULL
(
node
);
auto
item_with_index
=
AnfAlgo
::
VisitKernelWithReturnType
(
node
,
0
);
if
(
!
AnfAlgo
::
IsRealKernel
(
item_with_index
.
first
))
{
MS_LOG
(
EXCEPTION
)
<<
"Unexpected node:"
<<
item_with_index
.
first
->
DebugString
();
}
(
*
summary
)[
n
->
fullname_with_scope
()]
=
item_with_index
;
}
}
MS_LOG
(
DEBUG
)
<<
"Update summary end size: "
<<
(
*
summary
).
size
();
}
bool
ExistSummaryNode
(
const
KernelGraph
*
graph
)
{
auto
ret
=
graph
->
get_return
();
MS_EXCEPTION_IF_NULL
(
ret
);
auto
all_nodes
=
DeepLinkedGraphSearch
(
ret
);
for
(
auto
&
n
:
all_nodes
)
{
if
(
IsPrimitiveCNode
(
n
,
prim
::
kPrimScalarSummary
)
||
IsPrimitiveCNode
(
n
,
prim
::
kPrimTensorSummary
)
||
IsPrimitiveCNode
(
n
,
prim
::
kPrimImageSummary
)
||
IsPrimitiveCNode
(
n
,
prim
::
kPrimHistogramSummary
))
{
return
true
;
}
}
return
false
;
}
BaseRef
CreateOneTensor
(
const
AnfNodePtr
&
node
,
size_t
output_index
,
const
KernelGraph
&
graph
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
)
{
MS_EXCEPTION_IF_NULL
(
node
);
...
...
@@ -332,6 +292,19 @@ void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
(
void
)
tab_str
.
append
(
any
.
ToString
());
MS_LOG
(
INFO
)
<<
tab_str
;
}
bool
ExistSummaryNode
(
const
KernelGraph
*
graph
)
{
auto
ret
=
graph
->
get_return
();
MS_EXCEPTION_IF_NULL
(
ret
);
auto
all_nodes
=
DeepLinkedGraphSearch
(
ret
);
for
(
auto
&
n
:
all_nodes
)
{
if
(
IsPrimitiveCNode
(
n
,
prim
::
kPrimScalarSummary
)
||
IsPrimitiveCNode
(
n
,
prim
::
kPrimTensorSummary
)
||
IsPrimitiveCNode
(
n
,
prim
::
kPrimImageSummary
)
||
IsPrimitiveCNode
(
n
,
prim
::
kPrimHistogramSummary
))
{
return
true
;
}
}
return
false
;
}
}
// namespace
GraphId
SessionBasic
::
graph_sum_
=
0
;
...
...
@@ -604,6 +577,9 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
graph
->
set_manager
(
manager
);
}
graph
->
SetExecOrderByDefault
();
if
(
ExistSummaryNode
(
graph
.
get
()))
{
graph
->
set_summary_node_exist
(
true
);
}
opt
::
BackendCommonOptimization
(
graph
);
return
graph
;
}
...
...
@@ -667,6 +643,9 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
graph
->
set_manager
(
manager
);
}
graph
->
SetExecOrderByDefault
();
if
(
ExistSummaryNode
(
graph
.
get
()))
{
graph
->
set_summary_node_exist
(
true
);
}
return
graph
;
}
...
...
@@ -760,6 +739,36 @@ void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) {
(
void
)
std
::
copy
(
all_opt_list
.
begin
(),
all_opt_list
.
end
(),
std
::
back_inserter
(
*
node_list
));
}
void
SessionBasic
::
GetSummaryNodes
(
KernelGraph
*
graph
)
{
MS_LOG
(
DEBUG
)
<<
"Update summary Start"
;
MS_EXCEPTION_IF_NULL
(
graph
);
if
(
!
graph
->
summary_node_exist
())
{
return
;
}
auto
summary
=
graph
->
summary_nodes
();
auto
apply_list
=
TopoSort
(
graph
->
get_return
());
for
(
auto
&
n
:
apply_list
)
{
MS_EXCEPTION_IF_NULL
(
n
);
if
(
IsPrimitiveCNode
(
n
,
prim
::
kPrimScalarSummary
)
||
IsPrimitiveCNode
(
n
,
prim
::
kPrimTensorSummary
)
||
IsPrimitiveCNode
(
n
,
prim
::
kPrimImageSummary
)
||
IsPrimitiveCNode
(
n
,
prim
::
kPrimHistogramSummary
))
{
auto
cnode
=
n
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
if
(
cnode
->
inputs
().
size
()
<=
kSummaryGetItem
)
{
MS_LOG
(
EXCEPTION
)
<<
"the node Summary should have 2 inputs at least!"
;
}
auto
node
=
cnode
->
input
(
kSummaryGetItem
);
MS_EXCEPTION_IF_NULL
(
node
);
auto
item_with_index
=
AnfAlgo
::
VisitKernelWithReturnType
(
node
,
0
,
true
);
if
(
!
AnfAlgo
::
IsRealKernel
(
item_with_index
.
first
))
{
MS_LOG
(
EXCEPTION
)
<<
"Unexpected node:"
<<
item_with_index
.
first
->
DebugString
();
}
summary
[
n
->
fullname_with_scope
()]
=
item_with_index
;
}
}
graph
->
set_summary_nodes
(
summary
);
MS_LOG
(
DEBUG
)
<<
"Update summary end size: "
<<
summary
.
size
();
}
void
SessionBasic
::
Summary
(
KernelGraph
*
graph
)
{
if
(
summary_callback_
==
nullptr
)
{
return
;
...
...
@@ -769,8 +778,12 @@ void SessionBasic::Summary(KernelGraph *graph) {
if
(
!
exist_summary
)
{
return
;
}
std
::
unordered_map
<
std
::
string
,
std
::
pair
<
AnfNodePtr
,
int
>>
summary_outputs
;
GetSummaryNodes
(
graph
,
&
summary_outputs
);
GetSummaryNodes
(
graph
);
auto
summary_outputs
=
graph
->
summary_nodes
();
// do not exist summary node
if
(
summary_outputs
.
empty
())
{
return
;
}
std
::
map
<
std
::
string
,
tensor
::
TensorPtr
>
params_list
;
// fetch outputs apply kernel in session & run callback functions
for
(
auto
&
output_item
:
summary_outputs
)
{
...
...
mindspore/ccsrc/session/session_basic.h
浏览文件 @
dcb90588
...
...
@@ -92,6 +92,7 @@ class SessionBasic {
virtual
GraphId
GetGraphIdByNode
(
const
AnfNodePtr
&
)
const
{
return
kInvalidGraphId
;
}
virtual
GraphId
GetFinalRunGraph
()
const
{
return
kInvalidGraphId
;
}
virtual
void
SetActive
(
GraphId
,
GraphId
)
{}
virtual
void
GetSummaryNodes
(
KernelGraph
*
graph
);
protected:
virtual
void
LoadInputData
(
const
std
::
shared_ptr
<
KernelGraph
>
&
kernel_graph
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录