Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
c20cd122
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c20cd122
编写于
7月 24, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 24, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3110 Delete deprecated codes of ascend control flow
Merge pull request !3110 from zhoufeng/delete-deprecated-codes
上级
402378a6
2943cb1c
变更
15
展开全部
隐藏空白更改
内联
并排
Showing
15 changed file
with
30 addition
and
1584 deletion
+30
-1584
mindspore/ccsrc/backend/session/ascend_session.cc
mindspore/ccsrc/backend/session/ascend_session.cc
+11
-934
mindspore/ccsrc/backend/session/ascend_session.h
mindspore/ccsrc/backend/session/ascend_session.h
+6
-64
mindspore/ccsrc/backend/session/cpu_session.cc
mindspore/ccsrc/backend/session/cpu_session.cc
+1
-1
mindspore/ccsrc/backend/session/gpu_session.cc
mindspore/ccsrc/backend/session/gpu_session.cc
+1
-1
mindspore/ccsrc/backend/session/kernel_graph.cc
mindspore/ccsrc/backend/session/kernel_graph.cc
+0
-71
mindspore/ccsrc/backend/session/kernel_graph.h
mindspore/ccsrc/backend/session/kernel_graph.h
+0
-14
mindspore/ccsrc/backend/session/session_basic.cc
mindspore/ccsrc/backend/session/session_basic.cc
+2
-2
mindspore/ccsrc/backend/session/session_basic.h
mindspore/ccsrc/backend/session/session_basic.h
+1
-10
mindspore/ccsrc/pipeline/pynative/CMakeLists.txt
mindspore/ccsrc/pipeline/pynative/CMakeLists.txt
+1
-1
mindspore/ccsrc/vm/backend.cc
mindspore/ccsrc/vm/backend.cc
+0
-187
mindspore/ccsrc/vm/backend.h
mindspore/ccsrc/vm/backend.h
+1
-43
mindspore/ccsrc/vm/transform.cc
mindspore/ccsrc/vm/transform.cc
+4
-76
mindspore/ccsrc/vm/transform.h
mindspore/ccsrc/vm/transform.h
+0
-3
mindspore/ccsrc/vm/vm.cc
mindspore/ccsrc/vm/vm.cc
+2
-167
mindspore/ccsrc/vm/vm.h
mindspore/ccsrc/vm/vm.h
+0
-10
未找到文件。
mindspore/ccsrc/backend/session/ascend_session.cc
浏览文件 @
c20cd122
此差异已折叠。
点击以展开。
mindspore/ccsrc/backend/session/ascend_session.h
浏览文件 @
c20cd122
...
...
@@ -51,26 +51,16 @@ class AscendSession : public SessionBasic {
py
::
tuple
RunOp
(
const
OpRunInfo
&
op_run_info
,
const
GraphInfo
&
graph_info
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
)
override
;
// set parameters of final graph
GraphId
SetFinalGraphInput
(
const
std
::
vector
<
AnfNodePtr
>
&
args
)
override
;
// set output of final graph
void
SetFinalGraphOutput
(
const
BaseRef
&
output
)
override
;
// insert switch and set the relative active ops
void
SwitchCompile
(
GraphId
cond_g
,
GraphId
true_g
,
GraphId
false_g
,
const
AnfNodePtr
&
condition_output
)
override
;
// set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter
void
SetChildGraphInput
(
GraphId
g
,
const
VectorRef
&
args
)
override
;
// get graph id in child graphs by ME front anf node pointer
GraphId
GetGraphIdByNode
(
const
AnfNodePtr
&
front_anf
)
const
override
;
// get graph id of final graph
GraphId
GetFinalRunGraph
()
const
override
{
return
final_graph_id_
;
}
// insert active to graph
void
SetActive
(
GraphId
,
GraphId
)
override
;
// compile child graph when session have multiple child graphs
void
CompileChildGraph
(
const
KernelGraphPtr
&
child_graph
);
void
RecurseGetSummaryNodes
(
KernelGraph
*
graph
,
std
::
map
<
std
::
string
,
std
::
pair
<
AnfNodePtr
,
int
>>
*
summary
);
void
GetSummaryNodes
(
KernelGraph
*
graph
);
private:
void
RecurseSetSummaryNodes
(
KernelGraph
*
graph
,
std
::
map
<
std
::
string
,
std
::
pair
<
AnfNodePtr
,
int
>>
*
summary
);
void
SetSummaryNodes
(
KernelGraph
*
graph
)
override
;
void
InitRuntimeResource
();
void
SelectKernel
(
const
KernelGraph
&
kernel_graph
)
const
;
void
HardwareOptimize
(
const
std
::
shared_ptr
<
KernelGraph
>
&
kernel_graph
)
const
;
...
...
@@ -92,63 +82,21 @@ class AscendSession : public SessionBasic {
void
RunOpHardwareOptimize
(
const
std
::
shared_ptr
<
session
::
KernelGraph
>
&
kernel_graph
)
const
;
void
RunOpExecTask
(
const
std
::
shared_ptr
<
KernelGraph
>
&
kernel_graph
)
const
;
size_t
SetChildGraphInput
(
const
KernelGraphPtr
&
graph
,
const
AnfNodePtr
&
node
,
size_t
input_index
);
size_t
SetChildGraphInput
(
const
KernelGraphPtr
&
graph
,
const
ValuePtr
&
value
,
size_t
input_index
);
size_t
SetChildGraphInput
(
const
KernelGraphPtr
&
graph
,
const
VectorRef
&
vec_args
,
size_t
input_index
);
void
SetFinalGraphOutput
(
const
AnfNodePtr
&
node
);
void
SetFinalGraphOutput
(
const
ValuePtr
&
value
);
void
SetFinalGraphOutput
(
const
VectorRef
&
vec_output
);
void
SplitGraph
(
NotNull
<
KernelGraphPtr
>
graph
,
const
std
::
set
<
PrimitivePtr
>
&
cut_prims
,
const
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
);
// split graphs with recurse from root graph
void
SplitGraphs
(
NotNull
<
KernelGraphPtr
>
root_graph
);
void
BackendOptimization
(
const
std
::
vector
<
KernelGraphPtr
>
&
all_graphs
);
void
LinkChildGraphs
(
NotNull
<
KernelGraphPtr
>
graph
);
static
void
BackendOptimization
(
const
std
::
vector
<
KernelGraphPtr
>
&
all_graphs
);
static
void
LinkChildGraphs
(
NotNull
<
KernelGraphPtr
>
graph
);
void
RootGraphExecutorValidate
(
NotNull
<
KernelGraphPtr
>
graph
);
std
::
vector
<
AnfNodePtr
>
ConstructSplitedGraph
(
const
KernelGraphPtr
&
new_kernel_graph
,
const
std
::
vector
<
CNodePtr
>
&
list
);
void
RecurseCompileGraph
(
NotNull
<
KernelGraphPtr
>
graph
,
const
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
);
void
RecurseSplitGraph
(
NotNull
<
KernelGraphPtr
>
graph
,
const
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
);
AnfNodePtr
BindNewCallToNewGraph
(
NotNull
<
KernelGraphPtr
>
graph
,
const
std
::
vector
<
CNodePtr
>
&
child_graph_list
);
// merge execution order list of child graphs
void
MergeGraphExecOrder
();
// insert assion op to sync data bettween different graphs
void
InsertAssignToGraph
(
GraphId
graph_id
,
const
AnfNodePtr
&
from
,
const
AnfNodePtr
&
to
);
// insert mutiple assigns to graph
void
InsertMultipleAssignToGraph
(
GraphId
graph_id
,
const
AnfNodePtr
&
from
,
const
AnfNodePtr
&
to
);
// insert active op to graph
void
InsertStreamActiveToGraph
(
GraphId
graph_id
,
uint32_t
actived_stream
);
// get execute index of graph
size_t
ExecOrderOfChildGraph
(
GraphId
final_graph
,
GraphId
child_graph
);
// handle condition graph from vm
void
InsertSwitchToGraph
(
GraphId
condition_graph_id
,
GraphId
true_graph_id
);
// insert depend to graph, used to attch control nodes to graph
void
InsertDependToGraph
(
GraphId
graph_id
,
const
AnfNodePtr
&
attch_node
);
// insert depend to graph, used to attch control nodes to graph
void
InsertControlDependToGraph
(
GraphId
graph_id
,
const
AnfNodePtr
&
first_node
,
const
AnfNodePtr
&
second_node
);
// set child graph parameter if front arg is a anf
void
SetChildGraphParameter
(
const
AnfNodePtr
&
front_anf
,
GraphId
to_graph_id
,
size_t
input_idx
);
// set child graph parameter if front arg is a tensor
void
SetChildGraphParameter
(
const
tensor
::
TensorPtr
&
front_tensor
,
GraphId
to_graph_id
,
size_t
input_idx
);
// update the execution order of all child graphs
void
UpdateGraphOrder
(
GraphId
to_graph
);
// handle switch when merge
void
MergeSwitchCompile
();
// get graph order vector by graph id
std
::
vector
<
GraphId
>
&
GetGraphOrder
(
GraphId
final_graph_id
)
;
const
std
::
vector
<
GraphId
>
&
GetGraphOrder
(
GraphId
final_graph_id
)
const
;
// get graph order type vector by graph id
std
::
vector
<
GraphType
>
&
GetGraphOrderType
(
GraphId
final_graph_id
);
// copy output of if and else
void
CopyOutputOfIf
(
GraphId
false_graph_id
);
const
std
::
vector
<
GraphType
>
&
GetGraphOrderType
(
GraphId
final_graph_id
)
const
;
// check if graph cache exist
bool
GraphCacheExist
(
const
GraphInfo
&
graph_info
)
const
;
// insert all assign to child graph
void
InsertAllAssigns
();
// create fake output of final graph
AnfNodePtr
CreateFakeOutput
(
GraphId
final_graph_id
,
const
AnfNodePtr
&
true_output
);
// sync intial tensors' data to device
void
SyncInitialTenosrToDevice
();
void
SetFinalGraphSummaryFlag
(
const
std
::
shared_ptr
<
KernelGraph
>
&
kernel_graph
);
...
...
@@ -162,16 +110,10 @@ class AscendSession : public SessionBasic {
void
AssignStaticMemory
(
const
NotNull
<
KernelGraphPtr
>
graph
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
)
const
;
void
UpdateRefOutputMap
(
const
NotNull
<
KernelGraphPtr
>
graph
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
)
const
;
// member variables
// key is final_graph_id,value is child graph execute order of final graph
std
::
unordered_map
<
GraphId
,
std
::
vector
<
GraphId
>>
graph_execute_orders_
;
// key is final_graph_id,value is the graph types of child graphs
std
::
unordered_map
<
GraphId
,
std
::
vector
<
GraphType
>>
graph_order_types_
;
// record condition graph of while
std
::
unordered_map
<
GraphId
,
GraphId
>
while_condition_graphs_
;
// record all conditions
std
::
unordered_map
<
GraphId
,
std
::
pair
<
GraphId
,
GraphId
>>
switches_
;
std
::
unordered_map
<
GraphId
,
AnfNodePtr
>
condition_output_
;
// share parameters
std
::
vector
<
std
::
tuple
<
AnfNodePtr
,
GraphId
,
size_t
>>
assigns_
;
// initial tensors, these tensor will sync data to device before run graph
...
...
mindspore/ccsrc/backend/session/cpu_session.cc
浏览文件 @
c20cd122
...
...
@@ -108,7 +108,7 @@ void CPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten
kernel_graph
->
set_execution_order
(
execution_order
);
NamedSummaryOutputs
summary_outputs
;
if
(
enable_summary
)
{
G
etSummaryNodes
(
kernel_graph
.
get
());
S
etSummaryNodes
(
kernel_graph
.
get
());
summary_outputs
=
kernel_graph
->
summary_nodes
();
runtime_
.
IncreaseSummaryRefCount
(
summary_outputs
);
}
...
...
mindspore/ccsrc/backend/session/gpu_session.cc
浏览文件 @
c20cd122
...
...
@@ -217,7 +217,7 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
Reorder
(
&
execution_order
);
graph
->
set_execution_order
(
execution_order
);
// Get summary nodes.
G
etSummaryNodes
(
graph
.
get
());
S
etSummaryNodes
(
graph
.
get
());
// Remove NoOp from execution graph
opt
::
RemoveNopNode
(
graph
.
get
());
// Set graph manager.
...
...
mindspore/ccsrc/backend/session/kernel_graph.cc
浏览文件 @
c20cd122
...
...
@@ -898,27 +898,6 @@ void KernelGraph::ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodeP
std
::
queue
<
AnfNodePtr
>
seed_nodes
;
UpdateNodeEdgeList
(
&
seed_nodes
);
}
// update graph inputs in child graph
auto
it_real_inputs
=
std
::
find_if
(
real_inputs_
.
begin
(),
real_inputs_
.
end
(),
[
&
old_anf_node
](
const
std
::
pair
<
AnfNodePtr
,
std
::
vector
<
AnfNodePtr
>>
&
n
)
->
bool
{
return
n
.
first
==
old_anf_node
.
get
();
});
if
(
it_real_inputs
!=
real_inputs_
.
end
())
{
// erase old parameter in map
auto
old_args
=
it_real_inputs
->
second
;
real_inputs_
.
erase
(
it_real_inputs
);
// insert new parameter to map
auto
iter
=
std
::
find_if
(
real_inputs_
.
begin
(),
real_inputs_
.
end
(),
[
&
new_anf_node
](
const
std
::
pair
<
AnfNodePtr
,
std
::
vector
<
AnfNodePtr
>>
&
n
)
->
bool
{
return
n
.
first
==
new_anf_node
.
get
();
});
if
(
iter
!=
real_inputs_
.
end
())
{
MS_LOG
(
WARNING
)
<<
new_anf_node
->
DebugString
()
<<
" Already exist in real inputs, will be rewrited."
;
iter
->
second
=
old_args
;
}
else
{
real_inputs_
.
emplace_back
(
new_anf_node
,
old_args
);
}
}
}
void
KernelGraph
::
UpdateExecuteKernelStreamLabel
()
{
...
...
@@ -953,56 +932,6 @@ std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primi
return
result
;
}
void
KernelGraph
::
SetRealInput
(
const
AnfNodePtr
&
parameter
,
const
AnfNodePtr
&
arg
)
{
MS_EXCEPTION_IF_NULL
(
parameter
);
MS_EXCEPTION_IF_NULL
(
arg
);
MS_LOG
(
INFO
)
<<
"Parameter: "
<<
parameter
->
DebugString
()
<<
", real input : "
<<
arg
->
DebugString
();
MS_EXCEPTION_IF_NULL
(
parameter
);
MS_EXCEPTION_IF_NULL
(
arg
);
auto
iter
=
std
::
find_if
(
real_inputs_
.
begin
(),
real_inputs_
.
end
(),
[
&
parameter
](
const
std
::
pair
<
AnfNodePtr
,
std
::
vector
<
AnfNodePtr
>>
&
n
)
->
bool
{
return
n
.
first
==
parameter
;
});
if
(
iter
!=
real_inputs_
.
end
())
{
auto
&
args
=
iter
->
second
;
args
.
push_back
(
arg
);
}
else
{
real_inputs_
.
emplace_back
(
parameter
,
std
::
vector
<
AnfNodePtr
>
(
1
,
arg
));
}
}
void
KernelGraph
::
AddUnreuseArgs
(
const
AnfNodePtr
&
arg
,
const
std
::
shared_ptr
<
KernelGraph
>
&
from_graph
)
{
unreuse_args_
[
arg
]
=
from_graph
;
}
void
KernelGraph
::
UpdateCallRealInput
()
{
MS_LOG
(
INFO
)
<<
"Update graph id: "
<<
graph_id_
;
std
::
vector
<
std
::
pair
<
AnfNodePtr
,
std
::
vector
<
AnfNodePtr
>>>
real_inputs_map
;
for
(
auto
&
it
:
real_inputs_
)
{
auto
parameter
=
it
.
first
;
MS_EXCEPTION_IF_NULL
(
parameter
);
auto
real_inputs
=
it
.
second
;
std
::
vector
<
AnfNodePtr
>
new_real_inputs
;
for
(
auto
&
real_input
:
real_inputs
)
{
// if real input is a call node ,find the child graph output act as the new real input
auto
tmp_real_input
=
GetCallRealOutputs
(
real_input
);
std
::
copy
(
tmp_real_input
.
begin
(),
tmp_real_input
.
end
(),
std
::
back_inserter
(
new_real_inputs
));
// replace the call in unreuse_args_
auto
unreuse_arg_it
=
unreuse_args_
.
find
(
real_input
);
if
(
unreuse_arg_it
!=
unreuse_args_
.
end
())
{
auto
old_graph
=
unreuse_arg_it
->
second
;
for
(
auto
new_real_input
:
new_real_inputs
)
{
// if call reference graph output is parameter, it will be allowed to reuse
if
(
!
new_real_input
->
isa
<
Parameter
>
())
{
unreuse_args_
[
new_real_input
]
=
old_graph
;
}
}
}
}
real_inputs_map
.
emplace_back
(
parameter
,
new_real_inputs
);
}
real_inputs_
=
real_inputs_map
;
}
void
KernelGraph
::
PrintGraphExecuteOrder
()
const
{
MS_LOG
(
INFO
)
<<
"Graph:"
<<
graph_id_
<<
"execution order"
;
for
(
size_t
i
=
0
;
i
<
execution_order_
.
size
();
i
++
)
{
...
...
mindspore/ccsrc/backend/session/kernel_graph.h
浏览文件 @
c20cd122
...
...
@@ -131,16 +131,8 @@ class KernelGraph : public FuncGraph {
void
set_parent_graph
(
const
std
::
shared_ptr
<
KernelGraph
>
&
parent_graph
)
{
parent_graph_
=
parent_graph
;
}
// find anf node in graph
std
::
vector
<
CNodePtr
>
FindNodeByPrimitive
(
const
PrimitivePtr
&
primitive
)
const
;
// get real inputs
const
std
::
vector
<
std
::
pair
<
AnfNodePtr
,
std
::
vector
<
AnfNodePtr
>>>
&
real_inputs
()
const
{
return
real_inputs_
;
}
void
SetRealInput
(
const
AnfNodePtr
&
parameter
,
const
AnfNodePtr
&
arg
);
// mark unreused args
void
AddUnreuseArgs
(
const
AnfNodePtr
&
arg
,
const
std
::
shared_ptr
<
KernelGraph
>
&
from_graph
);
const
std
::
map
<
AnfNodePtr
,
std
::
shared_ptr
<
KernelGraph
>>
&
unreuse_args
()
const
{
return
unreuse_args_
;
}
// used to dump ir
std
::
string
ToString
()
const
override
;
// update the real input if the node is a call
void
UpdateCallRealInput
();
void
set_start_label
(
const
CNodePtr
&
start_label
)
{
start_label_
=
start_label
;
}
CNodePtr
get_start_label
()
{
return
start_label_
;
}
...
...
@@ -212,9 +204,6 @@ class KernelGraph : public FuncGraph {
// valid inputs
std
::
vector
<
bool
>
valid_inputs_
;
// new members for control sink process
// all child grahs refers to partial node
std
::
map
<
AnfNodePtr
,
std
::
shared_ptr
<
KernelGraph
>>
node_to_child_graphs_
;
// child graph execute order in root graph
std
::
vector
<
std
::
shared_ptr
<
KernelGraph
>>
child_graph_order_
;
...
...
@@ -223,9 +212,6 @@ class KernelGraph : public FuncGraph {
// parameter graph
std
::
shared_ptr
<
KernelGraph
>
parent_graph_
;
// record real parameters,inputs_ is the formal parameters
std
::
vector
<
std
::
pair
<
AnfNodePtr
,
std
::
vector
<
AnfNodePtr
>>>
real_inputs_
;
std
::
map
<
AnfNodePtr
,
std
::
shared_ptr
<
KernelGraph
>>
unreuse_args_
;
CNodePtr
start_label_
;
CNodePtr
end_goto_
;
...
...
mindspore/ccsrc/backend/session/session_basic.cc
浏览文件 @
c20cd122
...
...
@@ -890,7 +890,7 @@ void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {
void
SessionBasic
::
Reorder
(
std
::
vector
<
CNodePtr
>
*
node_list
)
{
AnfAlgo
::
ReorderExecList
(
NOT_NULL
(
node_list
));
}
void
SessionBasic
::
G
etSummaryNodes
(
KernelGraph
*
graph
)
{
void
SessionBasic
::
S
etSummaryNodes
(
KernelGraph
*
graph
)
{
MS_LOG
(
DEBUG
)
<<
"Update summary Start"
;
MS_EXCEPTION_IF_NULL
(
graph
);
if
(
!
graph
->
summary_node_exist
())
{
...
...
@@ -930,7 +930,7 @@ void SessionBasic::Summary(KernelGraph *graph) {
if
(
!
exist_summary
)
{
return
;
}
G
etSummaryNodes
(
graph
);
S
etSummaryNodes
(
graph
);
auto
summary_outputs
=
graph
->
summary_nodes
();
std
::
map
<
std
::
string
,
tensor
::
TensorPtr
>
params_list
;
// fetch outputs apply kernel in session & run callback functions
...
...
mindspore/ccsrc/backend/session/session_basic.h
浏览文件 @
c20cd122
...
...
@@ -92,19 +92,9 @@ class SessionBasic {
CNodePtr
HandleSwitchInputs
(
const
AnfNodePtr
&
anf_node
,
KernelGraph
*
graph
);
std
::
vector
<
AnfNodePtr
>
CreateSwitchOrPartialNode
(
const
CNodePtr
&
cnode
,
KernelGraph
*
graph
);
// set parameters of final graph
virtual
GraphId
SetFinalGraphInput
(
const
std
::
vector
<
AnfNodePtr
>
&
)
{
return
kInvalidGraphId
;
}
// set output of final graph
virtual
void
SetFinalGraphOutput
(
const
BaseRef
&
)
{}
// insert switch and set the relative active ops
virtual
void
SwitchCompile
(
GraphId
,
GraphId
,
GraphId
,
const
AnfNodePtr
&
)
{}
// set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter
virtual
void
SetChildGraphInput
(
GraphId
,
const
VectorRef
&
)
{}
// get graph id in child graphs by ME front anf node pointer
virtual
GraphId
GetGraphIdByNode
(
const
AnfNodePtr
&
)
const
{
return
kInvalidGraphId
;
}
virtual
GraphId
GetFinalRunGraph
()
const
{
return
kInvalidGraphId
;
}
virtual
void
SetActive
(
GraphId
,
GraphId
)
{}
virtual
void
GetSummaryNodes
(
KernelGraph
*
graph
);
void
AssignParamKey
(
const
KernelGraphPtr
&
kernel_graph
);
void
InitPSParamAndOptim
(
const
KernelGraphPtr
&
kernel_graph
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
inputs_const
);
virtual
bool
CheckModelInputs
(
uint32_t
graph_id
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
inputs
)
const
{
return
true
;
}
...
...
@@ -120,6 +110,7 @@ class SessionBasic {
#endif
protected:
virtual
void
SetSummaryNodes
(
KernelGraph
*
graph
);
// Get graph by graph id ,if not exist return null ptr
KernelGraphPtr
GetGraph
(
GraphId
graph_id
)
const
;
virtual
void
LoadInputData
(
const
std
::
shared_ptr
<
KernelGraph
>
&
kernel_graph
,
...
...
mindspore/ccsrc/pipeline/pynative/CMakeLists.txt
浏览文件 @
c20cd122
file
(
GLOB_RECURSE _PYNATIVE_SRC_LIST RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
"
base.cc"
"
pynative_execute.cc"
)
file
(
GLOB_RECURSE _PYNATIVE_SRC_LIST RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
"pynative_execute.cc"
)
if
(
ENABLE_GE
)
file
(
GLOB_RECURSE _GE_SRC_LIST RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
"pynative_execute_ge.cc"
)
...
...
mindspore/ccsrc/vm/backend.cc
浏览文件 @
c20cd122
...
...
@@ -21,7 +21,6 @@
#include "utils/log_adapter.h"
#include "ir/anf.h"
#include "utils/callbacks.h"
#include "utils/graph_utils.h"
#include "utils/base_ref_extends.h"
#include "backend/session/session_factory.h"
#include "common/utils.h"
...
...
@@ -34,19 +33,6 @@ namespace compile {
bool
Backend
::
GetCond
(
const
BaseRef
&
c
,
bool
*
const
value
)
{
return
BaseRefToBool
(
c
,
value
);
}
bool
Backend
::
GetIndex
(
const
BaseRef
&
c
,
int
*
const
value
)
{
return
BaseRefToInt
(
utils
::
cast
<
ValuePtr
>
(
c
),
value
);
}
LinConvertResult
MsBackend
::
GetMultiGraphRun
(
const
FuncGraphPtr
&
g
)
{
// multi_graph merge to one, big graph have paramters in begin and only have one output
MS_LOG
(
DEBUG
)
<<
"graph:"
<<
g
->
ToString
()
<<
" parameter size:"
<<
g
->
parameters
().
size
();
multi_result_
.
inputs
=
g
->
parameters
();
final_output_
=
NewValueNode
(
"fake_output"
);
multi_result_
.
outputs
=
{
final_output_
};
GraphId
final_g
=
target_sess_
->
GetFinalRunGraph
();
multi_result_
.
run
=
std
::
make_shared
<
RunFunc
>
(
[
final_g
,
this
](
const
VectorRef
&
args
)
->
VectorRef
{
return
MsRunGraph
(
final_g
,
args
,
""
);
});
return
multi_result_
;
}
LinConvertResult
MsBackend
::
MsConvert
(
const
AnfNodePtrList
&
lst
,
const
std
::
string
&
target
)
{
MS_LOG
(
DEBUG
)
<<
"MsConvert"
;
MS_EXCEPTION_IF_NULL
(
MsContext
::
GetInstance
());
...
...
@@ -96,149 +82,9 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri
return
result
;
}
void
MsBackend
::
SetSwitchActive
(
const
BaseRef
&
c
,
bool
cond
)
{
GraphId
active_g
=
simu_cond_map_
[
c
].
cond_graph_map
[
cond
];
GraphId
cond_g
=
kInvalidGraphId
;
if
(
utils
::
isa
<
AnfNodePtr
>
(
c
))
{
cond_g
=
target_sess_
->
GetGraphIdByNode
(
utils
::
cast
<
AnfNodePtr
>
(
c
));
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"cond not a anf node:"
<<
c
.
ToString
();
}
auto
before_cond
=
curr_switch_
;
if
(
curr_switch_
.
hash
()
!=
c
.
hash
())
{
// invoke while false->before true call
if
(
simu_cond_map_
[
before_cond
].
cond_graph_map
.
count
(
false
))
{
active_g
=
simu_cond_map_
[
before_cond
].
cond_graph_map
[
false
];
}
else
{
active_g
=
kInvalidGraphId
;
}
// while x < y:
// z = y + 1
// while z < c2:
// out = out + 1
// z = z + 1
if
(
active_g
==
cond_g
)
{
active_g
=
kInvalidGraphId
;
simu_cond_map_
[
before_cond
].
cond_graph_map
[
false
]
=
kInvalidGraphId
;
}
MS_LOG
(
DEBUG
)
<<
"invoke set active:"
<<
active_g
;
}
MS_LOG
(
DEBUG
)
<<
"switch set active:"
<<
active_g
<<
", "
<<
cond_g
;
target_sess_
->
SetActive
(
active_g
,
cond_g
);
}
void
MsBackend
::
SetSwitchGraph
()
{
MS_LOG
(
DEBUG
)
<<
"SetSwitchGraph curr_switch:"
<<
curr_switch_
.
ToString
();
if
(
is_switch_call_
)
{
GraphId
false_g
=
kInvalidGraphId
;
GraphId
true_g
=
kInvalidGraphId
;
MS_LOG
(
DEBUG
)
<<
"start SetSwitchGraph"
;
true_g
=
simu_cond_map_
[
curr_switch_
].
cond_graph_map
[
true
];
bool
curr_cond
=
simu_cond_map_
[
curr_switch_
].
curr_cond
;
if
(
!
curr_cond
)
{
if
(
simu_cond_map_
[
curr_switch_
].
cond_graph_map
.
count
(
curr_cond
))
{
// has false branch
false_g
=
simu_cond_map_
[
curr_switch_
].
cond_graph_map
[
false
];
}
GraphId
cond_g
=
kInvalidGraphId
;
if
(
utils
::
isa
<
AnfNodePtr
>
(
curr_switch_
))
{
cond_g
=
target_sess_
->
GetGraphIdByNode
(
utils
::
cast
<
AnfNodePtr
>
(
curr_switch_
));
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"cond not a anf node:"
<<
curr_switch_
.
ToString
();
}
MS_LOG
(
DEBUG
)
<<
"switch compile:"
<<
cond_g
<<
", "
<<
true_g
<<
", "
<<
false_g
;
target_sess_
->
SwitchCompile
(
cond_g
,
true_g
,
false_g
,
utils
::
cast
<
AnfNodePtr
>
(
curr_switch_
));
}
is_switch_call_
=
false
;
MS_LOG
(
DEBUG
)
<<
"end SetSwitchGraph:"
<<
curr_cond
<<
", "
<<
is_switch_call_
;
}
}
// convert node from formal parameter to actual parameter,
// and actual parameter is graph user's formal parameter.
// get top while graph's parameter in recall while.
AnfNodePtr
MsBackend
::
ConvertGraphInput
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
node
)
{
std
::
unordered_map
<
AnfNodePtr
,
size_t
>
params_index
;
auto
result
=
node
;
auto
graph
=
result
->
func_graph
();
while
(
func_graph
!=
graph
)
{
auto
iter
=
graph_user_inputs_
.
find
(
graph
);
if
(
iter
==
graph_user_inputs_
.
end
())
{
break
;
}
params_index
.
clear
();
auto
&
params
=
graph
->
parameters
();
for
(
size_t
i
=
0
;
i
<
params
.
size
();
++
i
)
{
params_index
[
params
[
i
]]
=
i
;
}
graph
=
iter
->
second
.
first
;
auto
&
inputs
=
iter
->
second
.
second
;
result
=
inputs
[
params_index
[
result
]];
}
return
result
;
}
void
MsBackend
::
SetGraphUserInputs
(
const
FuncGraphPtr
&
func_graph
,
const
FuncGraphPtr
&
user
,
const
AnfNodePtrList
&
inputs
)
{
if
(
graph_user_inputs_
.
find
(
func_graph
)
!=
graph_user_inputs_
.
end
())
{
return
;
}
graph_user_inputs_
[
func_graph
]
=
{
user
,
inputs
};
}
void
MsBackend
::
RecallGraphInput
(
const
FuncGraphPtr
&
func_graph
,
const
VectorRef
&
args
,
const
BaseRef
&
c
)
{
std
::
unordered_map
<
AnfNodePtr
,
size_t
>
params_index
;
auto
&
params
=
func_graph
->
parameters
();
for
(
size_t
i
=
0
;
i
<
params
.
size
();
++
i
)
{
params_index
[
params
[
i
]]
=
i
;
}
// recall all child graphs in this while
auto
&
graph_inputs
=
graph_inputs_
[
c
];
for
(
auto
&
iter
:
graph_inputs
)
{
auto
&
graph
=
iter
.
first
;
auto
&
old_args
=
iter
.
second
;
auto
&
result
=
graph_id_map_
[
graph
];
auto
&
inputs
=
result
.
inputs
;
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
auto
input
=
ConvertGraphInput
(
func_graph
,
inputs
[
i
]);
auto
it
=
params_index
.
find
(
input
);
if
(
it
!=
params_index
.
end
())
{
old_args
[
i
]
=
args
[
it
->
second
];
}
}
target_sess_
->
SetChildGraphInput
(
graph
,
old_args
);
}
graph_inputs_
.
erase
(
c
);
}
// compile set input output
VectorRef
MsBackend
::
MsSimuRunGraph
(
const
GraphId
&
g
,
const
VectorRef
&
args
)
{
MS_LOG
(
DEBUG
)
<<
"set graph input:"
<<
g
;
// switch maybe twice
target_sess_
->
SetChildGraphInput
(
g
,
args
);
if
(
is_switch_call_
)
{
if
(
!
curr_switch_
.
is_null
())
{
// push this {g, args} to all user while graph_inputs for nest while,
// when current condition recall over delete this cond in graph_inputs.
for
(
auto
&
iter
:
graph_inputs_
)
{
iter
.
second
.
push_back
({
g
,
args
});
}
if
(
graph_inputs_
.
find
(
curr_switch_
)
==
graph_inputs_
.
end
())
{
graph_inputs_
[
curr_switch_
].
push_back
({
g
,
args
});
}
}
bool
curr_cond
=
simu_cond_map_
[
curr_switch_
].
curr_cond
;
MS_LOG
(
DEBUG
)
<<
"switch call MsSimuRunGraph:"
<<
curr_cond
<<
", "
<<
g
;
simu_cond_map_
[
curr_switch_
].
cond_graph_map
[
curr_cond
]
=
g
;
SetSwitchGraph
();
}
std
::
vector
<
BaseRef
>
outputs
;
(
void
)
std
::
transform
(
graph_id_map_
[
g
].
outputs
.
begin
(),
graph_id_map_
[
g
].
outputs
.
end
(),
std
::
back_inserter
(
outputs
),
[](
const
AnfNodePtr
&
v
)
{
return
v
;
});
...
...
@@ -290,36 +136,6 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const s
return
outputs
;
}
SwitchCondStatus
MsBackend
::
SetSimuCond
(
const
BaseRef
&
c
,
bool
value
)
{
MS_LOG
(
DEBUG
)
<<
"set cond :"
<<
c
.
ToString
()
<<
", "
<<
simu_cond_map_
.
size
();
CondGraph
cond_graph
;
cond_graph
.
curr_cond
=
value
;
if
(
simu_cond_map_
.
find
(
c
)
==
simu_cond_map_
.
end
())
{
simu_cond_map_
[
c
]
=
cond_graph
;
}
if
(
simu_cond_map_
[
c
].
cond_graph_map
.
count
(
value
))
{
return
kCondAlreadyRun
;
}
simu_cond_map_
[
c
].
curr_cond
=
value
;
MS_LOG
(
DEBUG
)
<<
"end set cond "
;
return
kCondOk
;
}
void
MsBackend
::
SimulateRun
(
FinalVMPtr
rt
,
FuncGraphPtr
root
)
{
MS_LOG
(
DEBUG
)
<<
"Simulate run,root:"
<<
root
->
ToString
()
<<
", "
<<
root
->
parameters
().
size
();
std
::
vector
<
BaseRef
>
args
;
auto
parameters
=
root
->
parameters
();
(
void
)
std
::
transform
(
parameters
.
begin
(),
parameters
.
end
(),
std
::
back_inserter
(
args
),
[](
const
AnfNodePtr
&
v
)
{
return
v
;
});
MS_LOG
(
DEBUG
)
<<
"Simulate start"
;
(
void
)
target_sess_
->
SetFinalGraphInput
(
parameters
);
BaseRef
output
=
rt
->
Eval
(
VectorRef
(
args
));
target_sess_
->
SetFinalGraphOutput
(
output
);
MS_LOG
(
DEBUG
)
<<
"Simulate Eval end"
;
}
void
MsBackend
::
Link
(
GraphId
graph_id
)
{
if
(
graph_id
==
kInvalidGraphId
)
{
graph_id
=
target_sess_
->
GetFinalRunGraph
();
...
...
@@ -330,9 +146,6 @@ void MsBackend::Link(GraphId graph_id) {
Backend
::
Backend
(
const
std
::
string
&
name
)
:
name_
(
name
)
{
MS_LOG
(
DEBUG
)
<<
"select backend:"
<<
name
;
convert_fn_
=
backends
[
name_
];
is_switch_call_
=
false
;
is_multi_graph_sink_
=
false
;
simu_flag_
=
false
;
}
MsBackend
::
MsBackend
(
const
std
::
string
&
name
,
const
std
::
string
&
target
,
uint32_t
device_id
)
:
Backend
(
name
)
{
...
...
mindspore/ccsrc/vm/backend.h
浏览文件 @
c20cd122
...
...
@@ -43,50 +43,19 @@ class Backend {
LinkFuncType
convert_fn
()
{
return
convert_fn_
;
}
std
::
string
name
()
{
return
name_
;
}
virtual
void
SimulateRun
(
FinalVMPtr
,
FuncGraphPtr
)
{}
virtual
SwitchCondStatus
SetSimuCond
(
const
BaseRef
&
,
bool
)
{
return
kCondOk
;
}
virtual
bool
GetCond
(
const
BaseRef
&
c
,
bool
*
value
);
virtual
bool
GetIndex
(
const
BaseRef
&
c
,
int
*
value
);
virtual
void
SetSwitchGraph
()
{}
virtual
void
SetSwitchActive
(
const
BaseRef
&
,
bool
)
{}
virtual
void
RecallGraphInput
(
const
FuncGraphPtr
&
,
const
VectorRef
&
,
const
BaseRef
&
)
{}
virtual
void
SetGraphUserInputs
(
const
FuncGraphPtr
&
,
const
FuncGraphPtr
&
,
const
AnfNodePtrList
&
)
{}
virtual
GraphId
CompileGraph
(
NotNull
<
FuncGraphPtr
>
fg
)
{
return
kInvalidGraphId
;
}
void
set_curr_switch
(
const
BaseRef
&
value
)
{
curr_switch_
=
value
;
is_switch_call_
=
true
;
}
BaseRef
curr_switch
()
{
return
curr_switch_
;
}
virtual
void
Link
(
GraphId
)
{}
virtual
LinConvertResult
GetMultiGraphRun
(
const
FuncGraphPtr
&
)
{
return
LinConvertResult
();
}
virtual
void
SetDebugger
()
{
}
LinConvertResult
multi_result
()
{
return
multi_result_
;
}
void
set_multi_result
(
const
LinConvertResult
&
value
)
{
multi_result_
=
value
;
}
AnfNodePtr
final_output
()
const
{
return
final_output_
;
}
bool
is_multi_graph_sink
()
const
{
return
is_multi_graph_sink_
;
}
void
set_is_multi_graph_sink
(
bool
flag
)
{
is_multi_graph_sink_
=
flag
;
}
bool
simu_flag
()
const
{
return
simu_flag_
;
}
bool
is_switch_call
()
const
{
return
is_switch_call_
;
}
void
set_simu_flag
(
bool
simu
)
{
simu_flag_
=
simu
;
}
virtual
void
SetDebugger
()
{}
protected:
std
::
string
name_
;
LinkFuncType
convert_fn_
;
BaseRef
curr_switch_
;
// curr switch node
bool
is_multi_graph_sink_
;
bool
is_switch_call_
;
bool
simu_flag_
;
LinConvertResult
multi_result_
;
AnfNodePtr
final_output_
;
std
::
unordered_map
<
FuncGraphPtr
,
std
::
pair
<
FuncGraphPtr
,
AnfNodePtrList
>>
graph_user_inputs_
;
};
struct
CondGraph
{
bool
curr_cond
;
std
::
unordered_map
<
bool
,
GraphId
>
cond_graph_map
;
};
class
MsBackend
:
public
Backend
{
...
...
@@ -98,16 +67,7 @@ class MsBackend : public Backend {
VectorRef
MsRunGraph
(
const
GraphId
&
g
,
const
VectorRef
&
args
,
const
std
::
string
&
target
=
""
);
VectorRef
MsSimuRunGraph
(
const
GraphId
&
g
,
const
VectorRef
&
args
);
void
SimulateRun
(
FinalVMPtr
rt
,
FuncGraphPtr
root
)
override
;
SwitchCondStatus
SetSimuCond
(
const
BaseRef
&
c
,
bool
value
)
override
;
void
SetSwitchGraph
()
override
;
void
SetSwitchActive
(
const
BaseRef
&
c
,
bool
cond
)
override
;
void
RecallGraphInput
(
const
FuncGraphPtr
&
,
const
VectorRef
&
,
const
BaseRef
&
)
override
;
void
SetGraphUserInputs
(
const
FuncGraphPtr
&
,
const
FuncGraphPtr
&
,
const
AnfNodePtrList
&
)
override
;
void
Link
(
GraphId
)
override
;
AnfNodePtr
ConvertGraphInput
(
const
FuncGraphPtr
&
,
const
AnfNodePtr
&
);
LinConvertResult
GetMultiGraphRun
(
const
FuncGraphPtr
&
g
)
override
;
GraphId
CompileGraph
(
NotNull
<
FuncGraphPtr
>
fg
)
override
;
VectorRef
RunGraph
(
GraphId
graph_id
,
const
VectorRef
&
args
);
void
CreateOtherSession
(
const
std
::
string
&
target
);
...
...
@@ -121,9 +81,7 @@ class MsBackend : public Backend {
session
::
SessionPtr
other_sess_
;
std
::
string
target_device_
;
std
::
string
other_device_
;
std
::
unordered_map
<
BaseRef
,
CondGraph
,
BaseRefHash
>
simu_cond_map_
;
std
::
unordered_map
<
GraphId
,
LinConvertResult
>
graph_id_map_
;
std
::
unordered_map
<
BaseRef
,
std
::
list
<
std
::
pair
<
GraphId
,
VectorRef
>>
,
BaseRefHash
>
graph_inputs_
;
};
}
// namespace compile
}
// namespace mindspore
...
...
mindspore/ccsrc/vm/transform.cc
浏览文件 @
c20cd122
...
...
@@ -515,11 +515,7 @@ int CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &no
MS_LOG
(
DEBUG
)
<<
"LinConvert start"
;
LinConvertResult
result
;
if
(
backend_
->
simu_flag
())
{
result
=
backend_
->
GetMultiGraphRun
(
graph
);
}
else
{
result
=
lin_convert_
(
node_list
,
target
);
}
result
=
lin_convert_
(
node_list
,
target
);
if
(
result
.
run
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"LinConvert failed"
;
...
...
@@ -546,27 +542,6 @@ int CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &no
return
RET_SUCCESS
;
}
void
CompileGraph
::
AddSinkSwitch
(
const
CNodePtr
&
node
)
{
MS_LOG
(
DEBUG
)
<<
"AddSinkSwitch:"
<<
node
->
ToString
();
if
(
backend_
->
is_multi_graph_sink
())
{
VectorRef
args
;
args
.
emplace_back
(
-
1
);
MS_LOG
(
DEBUG
)
<<
"call::"
<<
height_
;
AddInst
(
Instruction
::
kCall
,
args
);
args
.
clear
();
args
.
emplace_back
(
node
->
input
(
1
));
AddInst
(
Instruction
::
kSwitchReturn
,
args
);
args
.
clear
();
args
.
emplace_back
(
false
);
args
.
emplace_back
(
Ref
(
node
->
input
(
1
)));
args
.
emplace_back
(
Ref
(
node
->
input
(
2
)));
args
.
emplace_back
(
Ref
(
node
->
input
(
3
)));
AddInst
(
Instruction
::
kSwitch
,
args
);
}
}
int
CompileGraph
::
InterpretNode
(
const
FuncGraphPtr
&
graph
,
const
CNodePtr
&
node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
MS_LOG
(
DEBUG
)
<<
"Interpret node: "
<<
node
->
DebugString
(
true
);
...
...
@@ -589,7 +564,6 @@ int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node)
AddPartial
(
node
);
}
else
if
(
IsPrimitive
(
fn
,
prim
::
kPrimSwitch
))
{
AddSwitch
(
node
);
AddSinkSwitch
(
node
);
}
else
if
(
IsPrimitive
(
fn
,
prim
::
kPrimSwitchLayer
))
{
AddSwitchLayer
(
node
);
}
else
if
(
IsPrimitive
(
fn
,
prim
::
kPrimMakeTuple
))
{
...
...
@@ -607,14 +581,6 @@ int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node)
return
RET_SUCCESS
;
}
void
CompileGraph
::
GenMultiGraphsRun
(
const
FuncGraphPtr
&
graph
)
{
auto
ret
=
LinConvert
(
graph
,
{});
if
(
ret
==
RET_FAILED
)
{
MS_LOG
(
EXCEPTION
)
<<
"MultiGraphRun failed."
;
}
AddReturn
(
nullptr
);
}
bool
CompileGraph
::
SplitGraph
(
const
FuncGraphPtr
&
graph
)
{
MS_LOG
(
DEBUG
)
<<
"Start split graph"
;
MS_EXCEPTION_IF_NULL
(
graph
);
...
...
@@ -659,11 +625,6 @@ bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) {
return
true
;
}
InstSet
CompileGraph
::
GenMultiGraphsSinkInst
(
const
FuncGraphPtr
&
graph
)
{
InstSet
inst
=
Run
(
graph
);
return
inst
;
}
InstSet
CompileGraph
::
Run
(
const
FuncGraphPtr
&
graph
)
{
MS_EXCEPTION_IF_NULL
(
graph
);
...
...
@@ -672,12 +633,8 @@ InstSet CompileGraph::Run(const FuncGraphPtr &graph) {
int
param_height
=
height_
;
MS_LOG
(
DEBUG
)
<<
"'param_height': "
<<
height_
<<
" to split graph: "
<<
graph
->
get_return
()
->
DebugString
(
true
);
if
(
backend_
->
simu_flag
())
{
GenMultiGraphsRun
(
graph
);
}
else
{
if
(
!
SplitGraph
(
graph
))
{
return
inst_
;
}
if
(
!
SplitGraph
(
graph
))
{
return
inst_
;
}
AddPadStack
(
param_height
);
...
...
@@ -712,12 +669,6 @@ void CompileGraph::AddPartial(const CNodePtr &node) {
if
(
!
IsValueNode
<
FuncGraph
>
(
fn
))
{
MS_LOG
(
EXCEPTION
)
<<
"The type of 1st input of node must be FuncGraph"
;
}
if
(
backend_
->
is_multi_graph_sink
())
{
auto
func_graph
=
GetValueNode
<
FuncGraphPtr
>
(
fn
);
args
.
emplace_back
(
func_graph
);
AnfNodePtrList
outs
(
inputs
.
begin
()
+
2
,
inputs
.
end
());
backend_
->
SetGraphUserInputs
(
func_graph
,
node
->
func_graph
(),
outs
);
}
for
(
size_t
i
=
1
;
i
<
inputs
.
size
();
i
++
)
{
args
.
emplace_back
(
Ref
(
inputs
[
i
]));
}
...
...
@@ -739,9 +690,6 @@ void CompileGraph::AddSwitch(const CNodePtr &node) {
MS_LOG
(
EXCEPTION
)
<<
"Length of inputs of primitive "
<<
prim
::
kPrimSwitch
->
name
()
<<
" is less than 4"
;
}
VectorRef
args
;
if
(
backend_
->
is_multi_graph_sink
())
{
args
.
emplace_back
(
true
);
}
args
.
emplace_back
(
Ref
(
inputs
[
1
]));
args
.
emplace_back
(
Ref
(
inputs
[
2
]));
args
.
emplace_back
(
Ref
(
inputs
[
3
]));
...
...
@@ -761,11 +709,7 @@ void CompileGraph::AddSwitchLayer(const CNodePtr &node) {
void
CompileGraph
::
AddReturn
(
const
CNodePtr
&
node
)
{
VectorRef
args
;
if
(
backend_
->
simu_flag
())
{
args
.
emplace_back
(
Ref
(
backend_
->
final_output
()));
}
else
{
args
.
emplace_back
(
Ref
(
node
->
input
(
1
)));
}
args
.
emplace_back
(
Ref
(
node
->
input
(
1
)));
args
.
emplace_back
(
height_
);
AddInst
(
Instruction
::
kReturn
,
args
);
}
...
...
@@ -783,11 +727,6 @@ void CompileGraph::AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim)
int
CompileGraph
::
AddCall
(
const
FuncGraphPtr
&
graph
,
const
CNodePtr
&
node
)
{
auto
inputs
=
node
->
inputs
();
AnfNodePtr
fn
=
inputs
[
0
];
if
(
backend_
->
is_multi_graph_sink
()
&&
IsValueNode
<
FuncGraph
>
(
fn
))
{
auto
func_graph
=
GetValueNode
<
FuncGraphPtr
>
(
fn
);
AnfNodePtrList
outs
(
inputs
.
begin
()
+
1
,
inputs
.
end
());
backend_
->
SetGraphUserInputs
(
func_graph
,
node
->
func_graph
(),
outs
);
}
(
void
)
Ref
(
fn
);
size_t
size
=
inputs
.
size
();
for
(
size_t
i
=
size
-
1
;
i
>
0
;
i
--
)
{
...
...
@@ -929,17 +868,6 @@ FinalVMPtr CompileGraphs::Link(const FuncGraphPtr &graph) {
}
FinalVMPtr
rt
=
std
::
make_shared
<
FinalVM
>
(
insts_
,
backend_
);
if
(
backend_
->
is_multi_graph_sink
())
{
backend_
->
set_simu_flag
(
true
);
MS_LOG
(
DEBUG
)
<<
"Start simulate"
;
backend_
->
SimulateRun
(
rt
,
graph
);
MS_LOG
(
DEBUG
)
<<
"Link graphs"
;
insts_
=
transform_
->
GenMultiGraphsSinkInst
(
graph
);
rt
->
set_insts
(
insts_
);
backend_
->
set_simu_flag
(
false
);
MS_LOG
(
DEBUG
)
<<
"End start simulate"
;
backend_
->
Link
(
kInvalidGraphId
);
}
MS_LOG
(
DEBUG
)
<<
"End"
;
return
rt
;
}
...
...
mindspore/ccsrc/vm/transform.h
浏览文件 @
c20cd122
...
...
@@ -54,12 +54,10 @@ class CompileGraph {
~
CompileGraph
()
=
default
;
InstSet
Run
(
const
FuncGraphPtr
&
func_graph
);
InstSet
GenMultiGraphsSinkInst
(
const
FuncGraphPtr
&
graph
);
bool
IsCut
(
const
AnfNodePtr
&
node
);
void
Push
(
const
AnfNodePtr
&
node
);
void
Tie
(
const
AnfNodePtr
&
n1
,
const
AnfNodePtr
&
n2
)
{
slots_
[
n2
]
=
slots_
[
n1
];
}
void
Ret
(
int
nargs
);
void
GenMultiGraphsRun
(
const
FuncGraphPtr
&
graph
);
int
Ref
(
const
AnfNodePtr
&
node
);
VectorRef
SplitNodes
(
const
FuncGraphPtr
&
func_graph
);
...
...
@@ -84,7 +82,6 @@ class CompileGraph {
int
LinConvert
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtrList
&
node_list
,
const
std
::
string
&
target
=
""
);
int
InterpretNode
(
const
FuncGraphPtr
&
func_graph
,
const
CNodePtr
&
node
);
int
AddCall
(
const
FuncGraphPtr
&
graph
,
const
CNodePtr
&
node
);
void
AddSinkSwitch
(
const
CNodePtr
&
node
);
void
AddPadStack
(
int
param_height
);
void
AddTailCall
(
const
AnfNodePtr
&
fn
,
size_t
size
);
void
AddPartial
(
const
CNodePtr
&
node
);
...
...
mindspore/ccsrc/vm/vm.cc
浏览文件 @
c20cd122
...
...
@@ -17,12 +17,9 @@
*/
#include "vm/vm.h"
#include <algorithm>
#include "vm/vmimpl.h"
#include "vm/backend.h"
#include "vm/transform.h"
#include "pipeline/jit/parse/data_converter.h"
#include "utils/base_ref_extends.h"
...
...
@@ -142,33 +139,10 @@ void FinalVM::Popsp() {
}
}
void
FinalVM
::
PushStatus
(
bool
is_switch_call
)
{
ret_status_
.
push
(
is_switch_call
);
}
bool
FinalVM
::
PopStatus
()
{
if
(
ret_status_
.
empty
())
{
return
false
;
}
bool
status
=
ret_status_
.
top
();
ret_status_
.
pop
();
return
status
;
}
void
FinalVM
::
DoJmp
(
const
BaseRef
&
jmp_orig
)
{
MS_LOG
(
DEBUG
)
<<
"Start"
;
BaseRef
jmp
=
jmp_orig
;
if
(
backend_
->
simu_flag
())
{
bool
is_switch_call
=
false
;
if
(
utils
::
isa
<
StructSimuSwitch
>
(
jmp
))
{
// need to inherit from Base
MS_LOG
(
DEBUG
)
<<
"Start jump StructSwitch"
;
auto
simu_value
=
utils
::
cast
<
std
::
shared_ptr
<
StructSimuSwitch
>>
(
jmp
);
jmp
=
simu_value
->
fn_
;
backend_
->
set_curr_switch
(
simu_value
->
value_
);
is_switch_call
=
true
;
}
PushStatus
(
is_switch_call
);
}
if
(
utils
::
isa
<
StructPartial
>
(
jmp
))
{
// need to inherit from Base
MS_LOG
(
DEBUG
)
<<
"Start jump StructPartial"
;
auto
new_jmp
=
utils
::
cast
<
std
::
shared_ptr
<
StructPartial
>>
(
jmp
);
...
...
@@ -270,13 +244,6 @@ void FinalVM::InstSwitchReturn(const VectorRef &args) {
MS_LOG
(
ERROR
)
<<
__FUNCTION__
<<
" requires one parameter, while the input size is "
<<
args
.
size
()
<<
"."
;
return
;
}
auto
rv
=
Ref
(
-
1
);
if
(
utils
::
isa
<
AnfNodePtr
>
(
rv
)
||
utils
::
isa
<
VectorRef
>
(
rv
))
{
auto
&
c
=
args
[
0
];
cond_out_
[
c
]
=
rv
;
}
Pop
(
1
);
Popsp
();
}
...
...
@@ -294,51 +261,12 @@ void FinalVM::InstReturn(const VectorRef &args) {
int
height
=
utils
::
cast
<
int
>
(
args
[
1
]);
auto
rv
=
Ref
(
rpos
);
if
(
backend_
->
simu_flag
())
{
auto
c
=
backend_
->
curr_switch
();
auto
status
=
PopStatus
();
if
(
status
)
{
auto
iter
=
cond_out_
.
find
(
c
);
if
(
iter
!=
cond_out_
.
end
())
{
rv
=
MergeArgs
(
rv
,
iter
->
second
);
cond_out_
.
erase
(
iter
);
}
}
if
(
backend_
->
is_switch_call
())
{
backend_
->
SetSwitchGraph
();
}
}
Pop
(
height
);
Push
(
rv
);
Popp
();
MS_LOG
(
DEBUG
)
<<
"End"
;
}
void
FinalVM
::
InstSimuPartial
(
const
VectorRef
&
args
)
{
const
size_t
args_size
=
2
;
if
(
args
.
size
()
<
args_size
)
{
MS_LOG
(
ERROR
)
<<
__FUNCTION__
<<
" requires "
<<
args_size
<<
" or more parameters, while the input size is "
<<
args
.
size
()
<<
"."
;
return
;
}
auto
&
node
=
args
[
0
];
if
(
!
utils
::
isa
<
FuncGraphPtr
>
(
node
))
{
MS_LOG
(
ERROR
)
<<
"The type of 1st input of node must be FuncGraph"
;
return
;
}
auto
fg
=
utils
::
cast
<
FuncGraphPtr
>
(
node
);
int
fn_
=
utils
::
cast
<
int
>
(
args
[
1
]);
auto
fn
=
utils
::
cast
<
int
>
(
Ref
(
fn_
));
MS_LOG
(
DEBUG
)
<<
"Partial argssize:"
<<
args
.
size
();
std
::
vector
<
BaseRef
>
outs
(
args
.
size
()
-
2
);
(
void
)
std
::
transform
(
args
.
begin
()
+
2
,
args
.
end
(),
outs
.
begin
(),
[
&
,
this
](
const
BaseRef
&
a
)
{
return
Ref
(
utils
::
cast
<
int
>
(
a
));
});
Push
(
std
::
make_shared
<
StructPartial
>
(
fn
,
VectorRef
(
outs
),
fg
));
}
void
FinalVM
::
InstRealPartial
(
const
VectorRef
&
args
)
{
const
size_t
args_size
=
1
;
if
(
args
.
size
()
<
args_size
)
{
...
...
@@ -358,91 +286,10 @@ void FinalVM::InstRealPartial(const VectorRef &args) {
void
FinalVM
::
InstPartial
(
const
VectorRef
&
args
)
{
MS_LOG
(
DEBUG
)
<<
"Start"
;
if
(
backend_
->
is_multi_graph_sink
())
{
InstSimuPartial
(
args
);
}
else
{
InstRealPartial
(
args
);
}
InstRealPartial
(
args
);
MS_LOG
(
DEBUG
)
<<
"End"
;
}
void
FinalVM
::
InstSimuSwitch
(
const
VectorRef
&
args
)
{
const
size_t
args_size
=
4
;
if
(
args
.
size
()
!=
args_size
)
{
MS_LOG
(
ERROR
)
<<
__FUNCTION__
<<
" requires "
<<
args_size
<<
" parameters, while the input size is "
<<
args
.
size
()
<<
"."
;
return
;
}
bool
cond
=
utils
::
cast
<
bool
>
(
args
[
0
]);
int
cond_node
=
utils
::
cast
<
int
>
(
args
[
1
]);
int
vtrue
=
utils
::
cast
<
int
>
(
args
[
2
]);
int
vfalse
=
utils
::
cast
<
int
>
(
args
[
3
]);
MS_LOG
(
DEBUG
)
<<
"Simu switch cond:"
<<
cond
;
BaseRef
c
=
Ref
(
cond_node
);
bool
bool_value
=
cond
;
SwitchCondStatus
cond_stat
=
backend_
->
SetSimuCond
(
c
,
bool_value
);
if
(
cond_stat
==
kCondAlreadyRun
)
{
MS_LOG
(
DEBUG
)
<<
"switch alreay run bool while true jmp"
;
BaseRef
jmp
=
Ref
(
vtrue
);
if
(
utils
::
isa
<
StructPartial
>
(
jmp
))
{
auto
new_jmp
=
utils
::
cast
<
std
::
shared_ptr
<
StructPartial
>>
(
jmp
);
backend_
->
RecallGraphInput
(
new_jmp
->
fg_
,
new_jmp
->
args_
,
c
);
}
cond_jmp_
[
c
]
=
Ref
(
vfalse
);
Push
(
static_cast
<
int
>
(
cond_stat
));
Popp
();
backend_
->
SetSwitchActive
(
c
,
bool_value
);
return
;
}
if
(
bool_value
)
{
Push
(
std
::
make_shared
<
StructSimuSwitch
>
(
Ref
(
vtrue
),
c
));
Pushsp
();
}
else
{
MergeJmpArgs
(
Ref
(
vfalse
),
c
);
Push
(
std
::
make_shared
<
StructSimuSwitch
>
(
Ref
(
vfalse
),
c
));
}
}
void
FinalVM
::
MergeJmpArgs
(
const
BaseRef
&
jmp
,
const
BaseRef
&
c
)
{
auto
iter
=
cond_jmp_
.
find
(
c
);
if
(
iter
==
cond_jmp_
.
end
())
{
return
;
}
auto
old_jmp
=
utils
::
cast
<
std
::
shared_ptr
<
StructPartial
>>
(
iter
->
second
);
auto
new_jmp
=
utils
::
cast
<
std
::
shared_ptr
<
StructPartial
>>
(
jmp
);
auto
&
old_args
=
old_jmp
->
args_
;
auto
&
new_args
=
new_jmp
->
args_
;
for
(
size_t
i
=
0
;
i
<
new_args
.
size
();
++
i
)
{
auto
&
old_arg
=
old_args
[
i
];
auto
&
new_arg
=
new_args
[
i
];
new_arg
=
MergeArgs
(
old_arg
,
new_arg
);
}
}
BaseRef
FinalVM
::
MergeArgs
(
const
BaseRef
&
first
,
const
BaseRef
&
second
)
{
MS_LOG
(
DEBUG
)
<<
__FUNCTION__
<<
": "
<<
first
.
ToString
()
<<
", "
<<
second
.
ToString
();
if
(
utils
::
isa
<
VectorRef
>
(
first
))
{
auto
old_vec_ref
=
utils
::
cast
<
VectorRef
>
(
first
);
if
(
utils
::
isa
<
VectorRef
>
(
second
))
{
auto
new_vec_ref
=
utils
::
cast
<
VectorRef
>
(
second
);
std
::
copy
(
new_vec_ref
.
begin
(),
new_vec_ref
.
end
(),
std
::
back_inserter
(
old_vec_ref
));
}
else
{
old_vec_ref
.
push_back
(
second
);
}
return
old_vec_ref
;
}
if
(
utils
::
isa
<
VectorRef
>
(
second
))
{
auto
new_vec_ref
=
utils
::
cast
<
VectorRef
>
(
second
);
new_vec_ref
.
push_back
(
first
);
return
new_vec_ref
;
}
return
VectorRef
({
first
,
second
});
}
void
FinalVM
::
InstRealSwitch
(
const
VectorRef
&
args
)
{
const
size_t
args_size
=
3
;
if
(
args
.
size
()
!=
args_size
)
{
...
...
@@ -472,11 +319,7 @@ void FinalVM::InstRealSwitch(const VectorRef &args) {
void
FinalVM
::
InstSwitch
(
const
VectorRef
&
args
)
{
MS_LOG
(
DEBUG
)
<<
"Start"
;
if
(
backend_
->
is_multi_graph_sink
())
{
InstSimuSwitch
(
args
);
}
else
{
InstRealSwitch
(
args
);
}
InstRealSwitch
(
args
);
MS_LOG
(
DEBUG
)
<<
"End"
;
}
...
...
@@ -580,14 +423,6 @@ void FinalVM::InstExternal(const VectorRef &args) {
VectorRef
tuple
;
RunFunctionRef
run_ref
=
utils
::
cast
<
RunFunctionRef
>
(
args
[
0
]);
compile
::
RunFuncPtr
fn
=
run_ref
.
func_
;
if
(
backend_
->
simu_flag
())
{
MS_LOG
(
DEBUG
)
<<
"Simu run"
;
if
(
args
.
size
()
==
1
)
{
MS_LOG
(
EXCEPTION
)
<<
"The number of args should be greater than 1, but got 1"
;
}
auto
simu_run_ref
=
utils
::
cast
<
RunFunctionRef
>
(
args
[
1
]);
fn
=
simu_run_ref
.
func_
;
}
for
(
size_t
i
=
2
;
i
<
args
.
size
();
++
i
)
{
auto
index
=
utils
::
cast
<
int
>
(
args
[
i
]);
tuple
.
push_back
(
Ref
(
index
));
...
...
mindspore/ccsrc/vm/vm.h
浏览文件 @
c20cd122
...
...
@@ -96,7 +96,6 @@ class FinalVM {
public:
// Create a VM with the specified instructions and backend.
explicit
FinalVM
(
const
InstSet
&
insts
,
const
BackendPtr
&
backend
);
virtual
~
FinalVM
()
=
default
;
BaseRef
Eval
(
const
VectorRef
&
args
);
...
...
@@ -104,10 +103,8 @@ class FinalVM {
void
InstTailCall
(
const
VectorRef
&
args
);
void
InstReturn
(
const
VectorRef
&
args
);
void
InstPartial
(
const
VectorRef
&
args
);
void
InstSimuPartial
(
const
VectorRef
&
args
);
void
InstRealPartial
(
const
VectorRef
&
args
);
void
InstSwitch
(
const
VectorRef
&
args
);
void
InstSimuSwitch
(
const
VectorRef
&
args
);
void
InstRealSwitch
(
const
VectorRef
&
args
);
void
InstTuple
(
const
VectorRef
&
args
);
void
InstPush
(
const
VectorRef
&
args
);
...
...
@@ -129,23 +126,16 @@ class FinalVM {
void
Popp
();
void
Pushsp
();
void
Popsp
();
void
PushStatus
(
bool
is_switch_call
);
bool
PopStatus
();
void
DoJmp
(
const
BaseRef
&
jmp
);
void
SyncData
(
const
py
::
object
&
args
);
void
MergeJmpArgs
(
const
BaseRef
&
jmp
,
const
BaseRef
&
c
);
BaseRef
MergeArgs
(
const
BaseRef
&
first
,
const
BaseRef
&
second
);
private:
InstSet
insts_
;
std
::
deque
<
BaseRef
>
insts_stack_
;
std
::
stack
<
int
>
retp_
;
std
::
stack
<
int
>
retsp_
;
std
::
stack
<
bool
>
ret_status_
;
int
pc_
;
int
sp_
;
std
::
unordered_map
<
BaseRef
,
BaseRef
,
BaseRefHash
>
cond_jmp_
;
std
::
unordered_map
<
BaseRef
,
BaseRef
,
BaseRefHash
>
cond_out_
;
BackendPtr
backend_
;
const
InstFunctionMap
inst_function_map
=
{
{
Instruction
::
kCall
,
[
this
](
const
VectorRef
&
args
)
{
InstCall
(
args
);
}},
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录