Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
472d87fe
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
472d87fe
编写于
6月 03, 2020
作者:
K
kswang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimize splitsort
上级
5c4731b7
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
82 addition
and
71 deletion
+82
-71
mindspore/ccsrc/vm/backend.cc
mindspore/ccsrc/vm/backend.cc
+20
-15
mindspore/ccsrc/vm/backend.h
mindspore/ccsrc/vm/backend.h
+4
-1
mindspore/ccsrc/vm/transform.cc
mindspore/ccsrc/vm/transform.cc
+58
-54
mindspore/ccsrc/vm/transform.h
mindspore/ccsrc/vm/transform.h
+0
-1
未找到文件。
mindspore/ccsrc/vm/backend.cc
浏览文件 @
472d87fe
...
...
@@ -65,8 +65,9 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri
result
.
outputs
=
outputs
;
result
.
graph_id
=
kInvalidGraphId
;
GraphId
graph_id
=
kInvalidGraphId
;
if
(
target
==
kCPUDevice
)
{
graph_id
=
cpu_sess_
->
CompileGraph
(
lst
,
outputs
);
if
(
target
!=
target_device_
&&
target
!=
""
)
{
CreateOtherSession
(
target
);
graph_id
=
other_sess_
->
CompileGraph
(
lst
,
outputs
);
}
else
{
graph_id
=
target_sess_
->
CompileGraph
(
lst
,
outputs
);
}
...
...
@@ -75,8 +76,8 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri
MS_LOG
(
INFO
)
<<
"PrecompileOnly, stop run graph"
;
return
result
;
}
if
(
target
==
kCPUDevice
)
{
cpu
_sess_
->
BuildGraph
(
graph_id
);
if
(
target
!=
target_device_
&&
target
!=
""
)
{
other
_sess_
->
BuildGraph
(
graph_id
);
}
else
if
(
!
is_multi_graph_sink_
)
{
target_sess_
->
BuildGraph
(
graph_id
);
}
...
...
@@ -278,8 +279,8 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const s
VectorRef
outputs
;
// call ms rungraph (graphId, input ,output)
if
(
target
==
kCPUDevice
)
{
cpu
_sess_
->
RunGraph
(
g
,
inputs
,
&
outputs
);
if
(
target
!=
target_device_
&&
target
!=
""
)
{
other
_sess_
->
RunGraph
(
g
,
inputs
,
&
outputs
);
}
else
{
target_sess_
->
RunGraph
(
g
,
inputs
,
&
outputs
);
}
...
...
@@ -341,16 +342,20 @@ MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_
}
target_sess_
->
Init
(
device_id
);
target_sess_
->
RegisterSummaryCallBackFunc
(
callbacks
::
SummarySaveCallback
);
if
(
target
==
kCPUDevice
)
{
cpu_sess_
=
target_sess_
;
}
else
{
cpu_sess_
=
session
::
SessionFactory
::
Get
().
Create
(
kCPUDevice
);
if
(
cpu_sess_
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"Create cpu session failed with target "
<<
target
<<
"."
;
}
cpu_sess_
->
Init
(
0
);
cpu_sess_
->
RegisterSummaryCallBackFunc
(
callbacks
::
SummarySaveCallback
);
target_device_
=
target
;
}
void
MsBackend
::
CreateOtherSession
(
const
std
::
string
&
target
)
{
if
(
other_sess_
!=
nullptr
&&
other_device_
==
target
)
{
return
;
}
other_sess_
=
session
::
SessionFactory
::
Get
().
Create
(
kCPUDevice
);
if
(
other_sess_
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"Session create failed!, please make sure target device:"
<<
target
<<
" is available."
;
}
other_sess_
->
Init
(
0
);
other_sess_
->
RegisterSummaryCallBackFunc
(
callbacks
::
SummarySaveCallback
);
other_device_
=
target
;
}
GraphId
MsBackend
::
CompileGraph
(
NotNull
<
FuncGraphPtr
>
fg
)
{
return
target_sess_
->
CompileGraph
(
fg
);
}
...
...
mindspore/ccsrc/vm/backend.h
浏览文件 @
472d87fe
...
...
@@ -107,10 +107,13 @@ class MsBackend : public Backend {
LinConvertResult
GetMultiGraphRun
(
const
FuncGraphPtr
&
g
)
override
;
GraphId
CompileGraph
(
NotNull
<
FuncGraphPtr
>
fg
)
override
;
VectorRef
RunGraph
(
GraphId
graph_id
,
const
VectorRef
&
args
);
void
CreateOtherSession
(
const
std
::
string
&
target
);
private:
session
::
SessionPtr
target_sess_
;
session
::
SessionPtr
cpu_sess_
;
session
::
SessionPtr
other_sess_
;
std
::
string
target_device_
;
std
::
string
other_device_
;
std
::
unordered_map
<
BaseRef
,
CondGraph
,
BaseRefHash
>
simu_cond_map_
;
std
::
unordered_map
<
GraphId
,
LinConvertResult
>
graph_id_map_
;
std
::
unordered_map
<
BaseRef
,
std
::
list
<
std
::
pair
<
GraphId
,
VectorRef
>>
,
BaseRefHash
>
graph_inputs_
;
...
...
mindspore/ccsrc/vm/transform.cc
浏览文件 @
472d87fe
...
...
@@ -21,6 +21,7 @@
#include <algorithm>
#include <map>
#include <queue>
#include <stack>
#include <set>
#include <string>
#include <vector>
...
...
@@ -75,7 +76,7 @@ std::string GetCNodeTarget(const AnfNodePtr &node) {
return
default_target
;
}
auto
primitive
=
value
->
cast
<
PrimitivePtr
>
();
ValuePtr
att_target
=
primitive
->
GetAttr
(
"target"
);
ValuePtr
att_target
=
primitive
->
GetAttr
(
"
primitive_
target"
);
if
(
att_target
!=
nullptr
)
{
std
::
string
target
=
GetValue
<
std
::
string
>
(
att_target
);
return
target
;
...
...
@@ -127,6 +128,58 @@ void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *n
}
}
}
std
::
vector
<
AnfNodePtr
>
SplitSort
(
const
FuncGraphPtr
&
graph
,
const
std
::
string
&
default_target
)
{
std
::
vector
<
AnfNodePtr
>
result
;
std
::
stack
<
AnfNodePtr
>
to_visit
;
std
::
stack
<
AnfNodePtr
>
next_to_visit
;
std
::
map
<
AnfNodePtr
,
size_t
>
nodes_ref
;
CalcNodeRefCount
(
graph
,
&
nodes_ref
);
std
::
string
handle_target
=
default_target
;
std
::
string
next_target
=
""
;
to_visit
.
push
(
graph
->
get_return
());
while
(
!
to_visit
.
empty
()
||
!
next_to_visit
.
empty
())
{
if
(
to_visit
.
empty
())
{
to_visit
.
swap
(
next_to_visit
);
handle_target
=
next_target
;
}
auto
&
node
=
to_visit
.
top
();
to_visit
.
pop
();
MS_EXCEPTION_IF_NULL
(
node
);
result
.
emplace_back
(
node
);
if
(
!
node
->
isa
<
CNode
>
())
{
continue
;
}
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
auto
node_inputs
=
cnode
->
inputs
();
std
::
reverse
(
node_inputs
.
begin
(),
node_inputs
.
end
());
for
(
auto
&
input
:
node_inputs
)
{
auto
iter
=
nodes_ref
.
find
(
input
);
if
(
iter
!=
nodes_ref
.
end
())
{
iter
->
second
--
;
if
(
iter
->
second
!=
0
)
{
continue
;
}
}
if
(
!
input
->
isa
<
CNode
>
())
{
to_visit
.
push
(
input
);
continue
;
}
std
::
string
input_target
=
GetCNodeTarget
(
input
);
if
(
input_target
==
handle_target
)
{
to_visit
.
push
(
input
);
}
else
if
(
next_to_visit
.
empty
()
||
input_target
==
next_target
)
{
next_to_visit
.
push
(
input
);
next_target
=
input_target
;
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"only support two different target"
;
}
}
}
std
::
reverse
(
result
.
begin
(),
result
.
end
());
return
result
;
}
}
// namespace
CompileGraph
::
CompileGraph
(
const
BackendPtr
&
backend
,
const
std
::
vector
<
PrimitivePtr
>
&
cut_list
)
...
...
@@ -180,65 +233,16 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) {
return
false
;
}
std
::
vector
<
AnfNodePtr
>
CompileGraph
::
SplitSort
(
const
FuncGraphPtr
&
graph
)
{
std
::
vector
<
AnfNodePtr
>
result
;
std
::
queue
<
AnfNodePtr
>
queue
;
std
::
queue
<
AnfNodePtr
>
next_queue
;
std
::
map
<
AnfNodePtr
,
size_t
>
nodes_ref
;
CalcNodeRefCount
(
graph
,
&
nodes_ref
);
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
std
::
string
queue_target
=
context_ptr
->
device_target
();
std
::
string
next_target
=
""
;
queue
.
push
(
graph
->
get_return
());
while
(
!
queue
.
empty
()
||
!
next_queue
.
empty
())
{
if
(
queue
.
empty
())
{
queue
.
swap
(
next_queue
);
queue_target
=
next_target
;
}
auto
&
node
=
queue
.
front
();
queue
.
pop
();
MS_EXCEPTION_IF_NULL
(
node
);
result
.
emplace_back
(
node
);
if
(
!
node
->
isa
<
CNode
>
())
{
continue
;
}
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
for
(
auto
&
input
:
cnode
->
inputs
())
{
auto
iter
=
nodes_ref
.
find
(
input
);
if
(
iter
!=
nodes_ref
.
end
())
{
iter
->
second
--
;
if
(
iter
->
second
!=
0
)
{
continue
;
}
}
if
(
!
input
->
isa
<
CNode
>
())
{
queue
.
push
(
input
);
continue
;
}
std
::
string
input_target
=
GetCNodeTarget
(
input
);
if
(
input_target
==
queue_target
)
{
queue
.
push
(
input
);
}
else
if
(
next_queue
.
empty
()
||
input_target
==
next_target
)
{
next_queue
.
push
(
input
);
next_target
=
input_target
;
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"only support two different target"
;
}
}
}
std
::
reverse
(
result
.
begin
(),
result
.
end
());
return
result
;
}
VectorRef
CompileGraph
::
SplitNodes
(
const
FuncGraphPtr
&
graph
)
{
MS_EXCEPTION_IF_NULL
(
graph
);
VectorRef
splits
;
VectorRef
split
;
auto
nodes
=
TopoSort
(
graph
->
get_return
());
if
(
ContainMultiTarget
(
nodes
))
{
nodes
=
SplitSort
(
graph
);
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
std
::
string
default_target
=
context_ptr
->
device_target
();
nodes
=
SplitSort
(
graph
,
default_target
);
}
std
::
string
last_target
;
MS_LOG
(
DEBUG
)
<<
"Split all nodes size:"
<<
nodes
.
size
();
...
...
mindspore/ccsrc/vm/transform.h
浏览文件 @
472d87fe
...
...
@@ -79,7 +79,6 @@ class CompileGraph {
private:
void
PushParameters
(
const
FuncGraphPtr
&
func_graph
);
std
::
vector
<
AnfNodePtr
>
SplitSort
(
const
FuncGraphPtr
&
graph
);
bool
SplitGraph
(
const
FuncGraphPtr
&
func_graph
);
int
LinConvert
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtrList
&
node_list
,
const
std
::
string
&
target
=
""
);
int
InterpretNode
(
const
FuncGraphPtr
&
func_graph
,
const
CNodePtr
&
node
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录