Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
1e8997f4
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1e8997f4
编写于
4月 23, 2020
作者:
K
kswang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimize sort for mem reuse and fix memreuse bug
上级
c984c48f
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
76 addition
and
81 deletion
+76
-81
mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_allocator.cc
...spore/ccsrc/pre_activate/mem_reuse/mem_reuse_allocator.cc
+2
-1
mindspore/ccsrc/session/anf_runtime_algorithm.cc
mindspore/ccsrc/session/anf_runtime_algorithm.cc
+5
-10
mindspore/ccsrc/session/anf_runtime_algorithm.h
mindspore/ccsrc/session/anf_runtime_algorithm.h
+0
-1
mindspore/ccsrc/session/kernel_graph.cc
mindspore/ccsrc/session/kernel_graph.cc
+66
-67
mindspore/ccsrc/session/kernel_graph.h
mindspore/ccsrc/session/kernel_graph.h
+3
-2
未找到文件。
mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_allocator.cc
浏览文件 @
1e8997f4
...
...
@@ -251,9 +251,10 @@ void BestFitMemReuse::ReleaseNodeUnusedOutput(const KernelDef *kernel_def_ptr) {
}
size_t
BestFitMemReuse
::
FindIndx
(
const
std
::
vector
<
MembufPtr
>
&
membuf_ptr_list
,
int
fac_idx
)
const
{
size_t
membuf_index
=
0
;
size_t
membuf_index
=
membuf_ptr_list
.
size
()
;
for
(
size_t
n
=
0
;
n
<
membuf_ptr_list
.
size
();
++
n
)
{
auto
membuf
=
membuf_ptr_list
[
n
];
MS_EXCEPTION_IF_NULL
(
membuf
);
if
(
membuf
->
index_
==
fac_idx
)
{
membuf_index
=
n
;
break
;
...
...
mindspore/ccsrc/session/anf_runtime_algorithm.cc
浏览文件 @
1e8997f4
...
...
@@ -851,17 +851,12 @@ void AnfRuntimeAlgorithm::SetNodeInput(const CNodePtr &node, const AnfNodePtr &i
bool
AnfRuntimeAlgorithm
::
IsCommunicationOp
(
const
AnfNodePtr
&
node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
auto
kernel_name
=
AnfAlgo
::
GetCNodeName
(
node
);
auto
kernel_type
=
AnfAlgo
::
GetKernelType
(
node
);
if
(
kernel_name
==
kAllReduceOpName
||
kernel_type
==
HCCL_KERNEL
)
{
return
true
;
if
(
!
node
->
isa
<
CNode
>
())
{
return
false
;
}
return
false
;
}
bool
AnfRuntimeAlgorithm
::
IsAllReduceOp
(
const
AnfNodePtr
&
node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
if
(
node
->
isa
<
CNode
>
()
&&
AnfAlgo
::
GetCNodeName
(
node
)
==
kAllReduceOpName
)
{
auto
kernel_name
=
AnfAlgo
::
GetCNodeName
(
node
);
if
(
kernel_name
==
kAllReduceOpName
||
kernel_name
==
kAllGatherOpName
||
kernel_name
==
kBroadcastOpName
||
kernel_name
==
kReduceScatterOpName
)
{
return
true
;
}
return
false
;
...
...
mindspore/ccsrc/session/anf_runtime_algorithm.h
浏览文件 @
1e8997f4
...
...
@@ -176,7 +176,6 @@ class AnfRuntimeAlgorithm {
// get real input index for some tbe ops which input order is different between me and tbe impl
static
size_t
GetRealInputIndex
(
const
AnfNodePtr
&
anf_node
,
const
size_t
cur_index
);
static
bool
IsCommunicationOp
(
const
AnfNodePtr
&
node
);
static
bool
IsAllReduceOp
(
const
AnfNodePtr
&
node
);
static
bool
IsGetNext
(
const
NotNull
<
AnfNodePtr
>
&
node
);
};
}
// namespace session
...
...
mindspore/ccsrc/session/kernel_graph.cc
浏览文件 @
1e8997f4
...
...
@@ -49,80 +49,81 @@ std::vector<AnfNodePtr> KernelGraph::outputs() const {
return
std
::
vector
<
AnfNodePtr
>
();
}
void
KernelGraph
::
SetExecOrderByDefault
()
{
std
::
stack
<
AnfNodePtr
>
seed_nodes
;
UpdateNodeEdgeList
(
&
seed_nodes
);
execution_order_
.
clear
();
std
::
unordered_set
<
AnfNodePtr
>
visited_nodes
;
std
::
queue
<
AnfNodePtr
>
zero_input_nodes
;
auto
visit_node_descendant
=
[
&
visited_nodes
,
this
](
const
AnfNodePtr
&
node
,
std
::
queue
<
AnfNodePtr
>
*
visit_queue
)
{
auto
it
=
node_output_edges_
.
find
(
node
);
if
(
it
==
node_output_edges_
.
end
())
{
// value node and parameter has no input,no need to print log
if
(
node
->
isa
<
CNode
>
())
{
MS_LOG
(
DEBUG
)
<<
"Can not find node ["
<<
node
->
DebugString
()
<<
"]"
;
}
return
;
void
KernelGraph
::
VisitNodeDescendants
(
const
AnfNodePtr
&
node
,
std
::
queue
<
AnfNodePtr
>
*
visit_queue
,
std
::
unordered_set
<
AnfNodePtr
>
*
visited_nodes
)
{
MS_EXCEPTION_IF_NULL
(
visit_queue
);
MS_EXCEPTION_IF_NULL
(
visited_nodes
);
auto
it
=
node_output_edges_
.
find
(
node
);
if
(
it
==
node_output_edges_
.
end
())
{
// value node and parameter has no input,no need to print log
if
(
node
->
isa
<
CNode
>
())
{
MS_LOG
(
DEBUG
)
<<
"Can not find node ["
<<
node
->
DebugString
()
<<
"]"
;
}
return
;
}
// visit all reduce node first, then other nodes
std
::
vector
<
AnfNodePtr
>
active_nodes
;
for
(
const
auto
&
output_edge
:
it
->
second
)
{
auto
next_node
=
output_edge
.
first
;
if
(
node_input_num_
.
find
(
next_node
)
==
node_input_num_
.
end
())
{
MS_EXCEPTION_IF_NULL
(
next_node
);
MS_LOG
(
EXCEPTION
)
<<
"Can't find node["
<<
next_node
->
DebugString
()
<<
"]"
;
}
// visit all reduce node first, then other nodes
std
::
vector
<
AnfNodePtr
>
active_nodes
;
for
(
const
auto
&
output_edge
:
it
->
second
)
{
auto
next_node
=
output_edge
.
first
;
if
(
node_input_num_
.
find
(
next_node
)
==
node_input_num_
.
end
())
{
MS_EXCEPTION_IF_NULL
(
next_node
);
MS_LOG
(
DEBUG
)
<<
"Decrease input:"
<<
next_node
->
DebugString
()
<<
",node:"
<<
node
->
DebugString
()
<<
",num: "
<<
node_input_num_
[
next_node
]
<<
",decrease num:"
<<
output_edge
.
second
;
if
(
node_input_num_
[
next_node
]
<
output_edge
.
second
)
{
MS_LOG
(
EXCEPTION
)
<<
"Input node:"
<<
next_node
->
DebugString
()
<<
",node_output_num"
<<
node_input_num_
[
next_node
]
<<
",depend edge:"
<<
output_edge
.
second
;
}
node_input_num_
[
next_node
]
=
node_input_num_
[
next_node
]
-
output_edge
.
second
;
// allreduce first
if
(
node_input_num_
[
next_node
]
==
0
&&
visited_nodes
.
find
(
next_node
)
==
visited_nodes
.
end
())
{
(
void
)
visited_nodes
.
insert
(
next_node
);
if
(
AnfAlgo
::
IsAllReduceOp
(
next_node
))
{
MS_LOG
(
DEBUG
)
<<
"visit node:"
<<
next_node
->
DebugString
();
visit_queue
->
push
(
next_node
);
}
else
{
active_nodes
.
emplace_back
(
next_node
);
}
MS_LOG
(
EXCEPTION
)
<<
"Can't find node["
<<
next_node
->
DebugString
()
<<
"]"
;
}
MS_EXCEPTION_IF_NULL
(
next_node
);
MS_LOG
(
DEBUG
)
<<
"Decrease input:"
<<
next_node
->
DebugString
()
<<
",node:"
<<
node
->
DebugString
()
<<
",num: "
<<
node_input_num_
[
next_node
]
<<
",decrease num:"
<<
output_edge
.
second
;
if
(
node_input_num_
[
next_node
]
<
output_edge
.
second
)
{
MS_LOG
(
EXCEPTION
)
<<
"Input node:"
<<
next_node
->
DebugString
()
<<
",node_output_num"
<<
node_input_num_
[
next_node
]
<<
",depend edge:"
<<
output_edge
.
second
;
}
node_input_num_
[
next_node
]
=
node_input_num_
[
next_node
]
-
output_edge
.
second
;
// allreduce first
if
(
node_input_num_
[
next_node
]
==
0
&&
visited_nodes
->
find
(
next_node
)
==
visited_nodes
->
end
())
{
(
void
)
visited_nodes
->
insert
(
next_node
);
if
(
AnfAlgo
::
IsCommunicationOp
(
next_node
))
{
MS_LOG
(
DEBUG
)
<<
"visit node:"
<<
next_node
->
DebugString
();
visit_queue
->
push
(
next_node
);
}
else
{
active_nodes
.
emplace_back
(
next_node
);
}
}
}
for
(
auto
&
node
:
active_nodes
)
{
MS_LOG
(
DEBUG
)
<<
"visit node:"
<<
node
->
DebugString
();
visit_queue
->
push
(
node
);
}
};
for
(
auto
&
node
:
active_nodes
)
{
MS_LOG
(
DEBUG
)
<<
"visit node:"
<<
node
->
DebugString
();
visit_queue
->
push
(
node
);
}
}
AnfNodePtr
last_allreduce_node
=
nullptr
;
std
::
queue
<
AnfNodePtr
>
allreduce_descendants
;
while
(
!
seed_nodes
.
empty
()
||
last_allreduce_node
!=
nullptr
)
{
void
KernelGraph
::
SetExecOrderByDefault
()
{
std
::
queue
<
AnfNodePtr
>
seed_nodes
;
UpdateNodeEdgeList
(
&
seed_nodes
);
execution_order_
.
clear
();
std
::
unordered_set
<
AnfNodePtr
>
visited_nodes
;
std
::
queue
<
AnfNodePtr
>
zero_input_nodes
;
AnfNodePtr
last_communication_node
=
nullptr
;
std
::
queue
<
AnfNodePtr
>
communication_descendants
;
while
(
!
seed_nodes
.
empty
()
||
last_communication_node
!=
nullptr
)
{
// seed nodes first, then visit last all reduce node descendant
if
(
seed_nodes
.
empty
())
{
visit_node_descendant
(
last_allreduce_node
,
&
allreduce_descendant
s
);
last_
allreduce
_node
=
nullptr
;
VisitNodeDescendants
(
last_communication_node
,
&
communication_descendants
,
&
visited_node
s
);
last_
communication
_node
=
nullptr
;
}
else
{
zero_input_nodes
.
push
(
seed_nodes
.
top
());
zero_input_nodes
.
push
(
seed_nodes
.
front
());
seed_nodes
.
pop
();
}
// all reduce node descendant first, then common queue
while
(
!
zero_input_nodes
.
empty
()
||
!
allreduce
_descendants
.
empty
())
{
while
(
!
zero_input_nodes
.
empty
()
||
!
communication
_descendants
.
empty
())
{
AnfNodePtr
node
=
nullptr
;
bool
is_
allreduce
_descendant
=
false
;
if
(
allreduce
_descendants
.
empty
())
{
bool
is_
communication
_descendant
=
false
;
if
(
communication
_descendants
.
empty
())
{
node
=
zero_input_nodes
.
front
();
zero_input_nodes
.
pop
();
}
else
{
node
=
allreduce
_descendants
.
front
();
allreduce
_descendants
.
pop
();
is_
allreduce
_descendant
=
true
;
node
=
communication
_descendants
.
front
();
communication
_descendants
.
pop
();
is_
communication
_descendant
=
true
;
}
// add execute node
MS_EXCEPTION_IF_NULL
(
node
);
...
...
@@ -130,19 +131,18 @@ void KernelGraph::SetExecOrderByDefault() {
execution_order_
.
push_back
(
node
->
cast
<
CNodePtr
>
());
}
// for all reduce node, visit last all reduce node descendant
if
(
AnfAlgo
::
Is
AllReduce
Op
(
node
))
{
if
(
last_
allreduce
_node
!=
nullptr
)
{
visit_node_descendant
(
last_allreduce_node
,
&
allreduce_descendant
s
);
if
(
AnfAlgo
::
Is
Communication
Op
(
node
))
{
if
(
last_
communication
_node
!=
nullptr
)
{
VisitNodeDescendants
(
last_communication_node
,
&
communication_descendants
,
&
visited_node
s
);
}
last_
allreduce
_node
=
node
;
}
else
if
(
is_
allreduce
_descendant
)
{
visit_node_descendant
(
node
,
&
allreduce_descendant
s
);
last_
communication
_node
=
node
;
}
else
if
(
is_
communication
_descendant
)
{
VisitNodeDescendants
(
node
,
&
communication_descendants
,
&
visited_node
s
);
}
else
{
visit_node_descendant
(
node
,
&
zero_input
_nodes
);
VisitNodeDescendants
(
node
,
&
zero_input_nodes
,
&
visited
_nodes
);
}
}
}
CheckLoop
();
}
...
...
@@ -467,7 +467,7 @@ bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue<Anf
return
true
;
}
void
KernelGraph
::
UpdateNodeEdgeList
(
std
::
stack
<
AnfNodePtr
>
*
seed_nodes
)
{
void
KernelGraph
::
UpdateNodeEdgeList
(
std
::
queue
<
AnfNodePtr
>
*
seed_nodes
)
{
node_output_edges_
.
clear
();
node_input_num_
.
clear
();
node_input_edges_
.
clear
();
...
...
@@ -483,7 +483,6 @@ void KernelGraph::UpdateNodeEdgeList(std::stack<AnfNodePtr> *seed_nodes) {
seed_nodes
->
push
(
node
);
continue
;
}
if
(
!
node
->
isa
<
CNode
>
())
{
continue
;
}
...
...
mindspore/ccsrc/session/kernel_graph.h
浏览文件 @
1e8997f4
...
...
@@ -22,7 +22,6 @@
#include <utility>
#include <string>
#include <queue>
#include <stack>
#include <map>
#include <unordered_set>
#include "ir/func_graph.h"
...
...
@@ -94,8 +93,10 @@ class KernelGraph : public FuncGraph {
private:
// remove value node form graph
bool
RemoveValueNodeFromGraph
(
const
ValueNodePtr
&
value_node
);
void
VisitNodeDescendants
(
const
AnfNodePtr
&
node
,
std
::
queue
<
AnfNodePtr
>
*
visit_queue
,
std
::
unordered_set
<
AnfNodePtr
>
*
visited_nodes
);
// update node edge list
void
UpdateNodeEdgeList
(
std
::
stack
<
AnfNodePtr
>
*
seed_nodes
);
void
UpdateNodeEdgeList
(
std
::
queue
<
AnfNodePtr
>
*
seed_nodes
);
// add node depend edge by data edge or control depend
void
AddDependEdge
(
const
AnfNodePtr
&
node
,
const
AnfNodePtr
&
input
,
size_t
depend_edge_num
);
// handle control depend
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录