Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
6ae8345c
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6ae8345c
编写于
4月 26, 2020
作者:
R
rick_sanchez
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor vm module for multigraph sink
上级
1b5fb395
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
561 addition
and
108 deletion
+561
-108
mindspore/ccsrc/session/ascend_session.cc
mindspore/ccsrc/session/ascend_session.cc
+56
-24
mindspore/ccsrc/session/ascend_session.h
mindspore/ccsrc/session/ascend_session.h
+4
-0
mindspore/ccsrc/session/kernel_graph.h
mindspore/ccsrc/session/kernel_graph.h
+1
-1
mindspore/ccsrc/utils/base_ref.h
mindspore/ccsrc/utils/base_ref.h
+9
-0
mindspore/ccsrc/vm/backend.cc
mindspore/ccsrc/vm/backend.cc
+74
-32
mindspore/ccsrc/vm/backend.h
mindspore/ccsrc/vm/backend.h
+10
-6
mindspore/ccsrc/vm/transform.cc
mindspore/ccsrc/vm/transform.cc
+19
-4
mindspore/ccsrc/vm/vm.cc
mindspore/ccsrc/vm/vm.cc
+76
-40
mindspore/ccsrc/vm/vm.h
mindspore/ccsrc/vm/vm.h
+9
-1
tests/st/control/test_multigraph_sink.py
tests/st/control/test_multigraph_sink.py
+184
-0
tests/ut/python/pynative_mode/test_multigraph_sink.py
tests/ut/python/pynative_mode/test_multigraph_sink.py
+119
-0
未找到文件。
mindspore/ccsrc/session/ascend_session.cc
浏览文件 @
6ae8345c
...
...
@@ -800,45 +800,77 @@ void AscendSession::UpdateGraphOrder(GraphId to_graph_id) {
}
}
size_t
AscendSession
::
SetChildGraphInput
(
const
KernelGraphPtr
&
graph
,
const
AnfNodePtr
&
node
,
size_t
input_index
)
{
auto
output_num
=
AnfAlgo
::
GetOutputTensorNum
(
node
);
if
(
output_num
>
1
&&
!
AnfAlgo
::
CheckPrimitiveType
(
node
,
prim
::
kPrimTupleGetItem
))
{
return
input_index
+
output_num
;
}
auto
&
graph_inputs
=
graph
->
inputs
();
auto
&
valid_inputs
=
graph
->
ValidInputs
();
if
(
valid_inputs
[
input_index
])
{
SetChildGraphParameter
(
node
,
graph_inputs
[
input_index
]);
}
else
{
MS_LOG
(
DEBUG
)
<<
"Invalid input arg: "
<<
node
->
DebugString
();
}
return
++
input_index
;
}
size_t
AscendSession
::
SetChildGraphInput
(
const
KernelGraphPtr
&
graph
,
const
ValuePtr
&
value
,
size_t
input_index
)
{
MS_EXCEPTION_IF_NULL
(
value
);
if
(
!
value
->
isa
<
Tensor
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"Value Node should be a tensor, unexpected value: "
<<
value
->
ToString
();
}
auto
&
graph_inputs
=
graph
->
inputs
();
SetChildGraphParameter
(
value
->
cast
<
TensorPtr
>
(),
graph_inputs
[
input_index
]);
return
++
input_index
;
}
size_t
AscendSession
::
SetChildGraphInput
(
const
KernelGraphPtr
&
graph
,
const
VectorRef
&
vec_args
,
size_t
input_index
)
{
auto
index
=
input_index
;
for
(
auto
&
arg
:
vec_args
)
{
if
(
utils
::
isa
<
AnfNodePtr
>
(
arg
))
{
// arg is a anf node
auto
node
=
utils
::
cast
<
AnfNodePtr
>
(
arg
);
index
=
SetChildGraphInput
(
graph
,
node
,
input_index
);
}
else
if
(
utils
::
isa
<
ValuePtr
>
(
arg
))
{
// arg is a tensor
auto
value
=
utils
::
cast
<
ValuePtr
>
(
arg
);
index
=
SetChildGraphInput
(
graph
,
value
,
input_index
);
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"Unexpected arg type "
<<
arg
.
ToString
();
}
}
return
index
;
}
void
AscendSession
::
SetChildGraphInput
(
GraphId
g
,
const
VectorRef
&
args
)
{
MS_LOG
(
INFO
)
<<
"Set input of graph "
<<
g
;
auto
to_graph
=
GetGraph
(
g
);
MS_EXCEPTION_IF_NULL
(
to_graph
);
DumpGraphInputArgs
(
args
);
UpdateGraphOrder
(
g
);
std
::
vector
<
AnfNodePtr
>
graph_inputs
=
to_graph
->
inputs
();
auto
valid_inputs
=
to_graph
->
ValidInputs
();
auto
&
graph_inputs
=
to_graph
->
inputs
();
auto
real_args
=
GetRealArgs
(
to_graph
,
args
);
size_t
input_index
=
0
;
for
(
size_t
i
=
0
;
i
<
real_args
.
size
();
i
++
)
{
if
(
input_index
>=
graph_inputs
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"input_index "
<<
input_index
<<
" out of range size "
<<
graph_inputs
.
size
();
}
if
(
utils
::
isa
<
AnfNodePtr
>
(
real_args
[
i
]))
{
auto
&
real_arg
=
real_args
[
i
];
if
(
utils
::
isa
<
AnfNodePtr
>
(
real_arg
))
{
// arg is a anf node
auto
real_arg
=
utils
::
cast
<
AnfNodePtr
>
(
real_args
[
i
]);
auto
real_arg_output_num
=
AnfAlgo
::
GetOutputTensorNum
(
real_arg
);
if
(
!
AnfAlgo
::
CheckPrimitiveType
(
real_arg
,
prim
::
kPrimTupleGetItem
)
&&
real_arg_output_num
>
1
)
{
input_index
+=
real_arg_output_num
;
continue
;
}
if
(
valid_inputs
[
input_index
])
{
SetChildGraphParameter
(
real_arg
,
graph_inputs
[
input_index
]);
}
else
{
MS_LOG
(
DEBUG
)
<<
"Invalid input arg"
<<
real_arg
->
DebugString
();
}
input_index
++
;
}
else
if
(
utils
::
isa
<
ValuePtr
>
(
args
[
i
]))
{
auto
value
=
utils
::
cast
<
ValuePtr
>
(
args
[
i
]);
MS_EXCEPTION_IF_NULL
(
value
);
auto
node
=
utils
::
cast
<
AnfNodePtr
>
(
real_arg
);
input_index
=
SetChildGraphInput
(
to_graph
,
node
,
input_index
);
}
else
if
(
utils
::
isa
<
ValuePtr
>
(
real_arg
))
{
// arg is a tensor
if
(
!
value
->
isa
<
Tensor
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"Value Node should be a tensor, unexpected value: "
<<
value
->
ToString
();
}
SetChildGraphParameter
(
value
->
cast
<
TensorPtr
>
(),
graph_inputs
[
input_index
]);
input_index
++
;
auto
value
=
utils
::
cast
<
ValuePtr
>
(
real_arg
);
input_index
=
SetChildGraphInput
(
to_graph
,
value
,
input_index
);
}
else
if
(
utils
::
isa
<
VectorRef
>
(
real_arg
))
{
// arg is a VectorRef
auto
vec_args
=
utils
::
cast
<
VectorRef
>
(
real_arg
);
input_index
=
SetChildGraphInput
(
to_graph
,
vec_args
,
input_index
);
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"Unexpected arg type "
<<
args
[
i
]
.
ToString
();
MS_LOG
(
EXCEPTION
)
<<
"Unexpected arg type "
<<
real_arg
.
ToString
();
}
}
MS_LOG
(
INFO
)
<<
"Finish!"
;
...
...
mindspore/ccsrc/session/ascend_session.h
浏览文件 @
6ae8345c
...
...
@@ -79,6 +79,10 @@ class AscendSession : public SessionBasic {
void
RunOpHardwareOptimize
(
const
std
::
shared_ptr
<
session
::
KernelGraph
>
&
kernel_graph
)
const
;
void
RunOpExecTask
(
const
std
::
shared_ptr
<
KernelGraph
>
&
kernel_graph
)
const
;
size_t
SetChildGraphInput
(
const
KernelGraphPtr
&
graph
,
const
AnfNodePtr
&
node
,
size_t
input_index
);
size_t
SetChildGraphInput
(
const
KernelGraphPtr
&
graph
,
const
ValuePtr
&
value
,
size_t
input_index
);
size_t
SetChildGraphInput
(
const
KernelGraphPtr
&
graph
,
const
VectorRef
&
vec_args
,
size_t
input_index
);
// merge execution order list of child graphs
void
MergeGraphExecOrder
();
// insert assion op to sync data bettween different graphs
...
...
mindspore/ccsrc/session/kernel_graph.h
浏览文件 @
6ae8345c
...
...
@@ -88,7 +88,7 @@ class KernelGraph : public FuncGraph {
void
set_executable
(
bool
executable
)
{
executable_
=
executable
;
}
// set invalid inputs for control sink
std
::
vector
<
bool
>
*
MutableValidInputs
()
{
return
&
valid_inputs_
;
}
std
::
vector
<
bool
>
ValidInputs
()
{
return
valid_inputs_
;
}
const
std
::
vector
<
bool
>
&
ValidInputs
()
const
{
return
valid_inputs_
;
}
private:
// remove value node form graph
...
...
mindspore/ccsrc/utils/base_ref.h
浏览文件 @
6ae8345c
...
...
@@ -228,6 +228,8 @@ T cast(const BaseRef &handle) {
class
VectorRef
:
public
BaseRef
{
public:
using
value_type
=
BaseRef
;
VectorRef
()
{}
explicit
VectorRef
(
const
std
::
vector
<
BaseRef
>
&
elements
)
:
elements_
(
elements
)
{}
VectorRef
(
const
const_iterator
&
begin
,
const
const_iterator
&
end
)
:
elements_
(
begin
,
end
)
{}
...
...
@@ -251,6 +253,13 @@ class VectorRef : public BaseRef {
return
elements_
[
dim
];
}
BaseRef
&
operator
[](
const
std
::
size_t
&
dim
)
{
if
(
dim
>=
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"Out of the size of the tuple."
;
}
return
elements_
[
dim
];
}
uint32_t
type
()
const
override
{
return
tid
();
}
std
::
string
ToString
()
const
override
;
std
::
vector
<
BaseRef
>
&
elements
()
{
return
elements_
;
}
...
...
mindspore/ccsrc/vm/backend.cc
浏览文件 @
6ae8345c
...
...
@@ -143,6 +143,66 @@ void MsBackend::SetSwitchGraph() {
}
}
// convert node from formal parameter to actual parameter,
// and actual parameter is graph user's formal parameter.
// get top while graph's parameter in recall while.
AnfNodePtr
MsBackend
::
ConvertGraphInput
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
node
)
{
std
::
unordered_map
<
AnfNodePtr
,
size_t
>
params_index
;
auto
result
=
node
;
auto
graph
=
result
->
func_graph
();
while
(
func_graph
!=
graph
)
{
auto
iter
=
graph_user_inputs_
.
find
(
graph
);
if
(
iter
==
graph_user_inputs_
.
end
())
{
break
;
}
params_index
.
clear
();
auto
&
params
=
graph
->
parameters
();
for
(
size_t
i
=
0
;
i
<
params
.
size
();
++
i
)
{
params_index
[
params
[
i
]]
=
i
;
}
graph
=
iter
->
second
.
first
;
auto
&
inputs
=
iter
->
second
.
second
;
result
=
inputs
[
params_index
[
result
]];
}
return
result
;
}
void
MsBackend
::
SetGraphUserInputs
(
const
FuncGraphPtr
&
func_graph
,
const
FuncGraphPtr
&
user
,
const
AnfNodePtrList
&
inputs
)
{
if
(
graph_user_inputs_
.
find
(
func_graph
)
!=
graph_user_inputs_
.
end
())
{
return
;
}
graph_user_inputs_
[
func_graph
]
=
{
user
,
inputs
};
}
void
MsBackend
::
RecallGraphInput
(
const
FuncGraphPtr
&
func_graph
,
const
VectorRef
&
args
,
const
BaseRef
&
c
)
{
std
::
unordered_map
<
AnfNodePtr
,
size_t
>
params_index
;
auto
&
params
=
func_graph
->
parameters
();
for
(
size_t
i
=
0
;
i
<
params
.
size
();
++
i
)
{
params_index
[
params
[
i
]]
=
i
;
}
// recall all child graphs in this while
auto
&
graph_inputs
=
graph_inputs_
[
c
];
for
(
auto
&
iter
:
graph_inputs
)
{
auto
&
graph
=
iter
.
first
;
auto
&
old_args
=
iter
.
second
;
auto
&
result
=
graph_id_map_
[
graph
];
auto
&
inputs
=
result
.
inputs
;
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
auto
input
=
ConvertGraphInput
(
func_graph
,
inputs
[
i
]);
auto
it
=
params_index
.
find
(
input
);
if
(
it
!=
params_index
.
end
())
{
old_args
[
i
]
=
args
[
it
->
second
];
}
}
sess_
->
SetChildGraphInput
(
graph
,
old_args
);
}
graph_inputs_
.
erase
(
c
);
}
// compile set input output
VectorRef
MsBackend
::
MsSimuRunGraph
(
const
GraphId
&
g
,
const
VectorRef
&
args
)
{
MS_LOG
(
DEBUG
)
<<
"set graph input:"
<<
g
;
...
...
@@ -150,13 +210,20 @@ VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) {
sess_
->
SetChildGraphInput
(
g
,
args
);
if
(
is_switch_call_
)
{
bool
curr_cond
=
simu_cond_map_
[
curr_switch_
].
curr_cond
;
MS_LOG
(
DEBUG
)
<<
"switch call MsSimuRunGraph:"
<<
curr_cond
;
if
(
0
==
simu_cond_map_
[
curr_switch_
].
cond_graph_map
.
count
(
curr_cond
))
{
MS_LOG
(
DEBUG
)
<<
"switch call MsSimuRunGraph:"
<<
curr_cond
<<
", "
<<
g
;
simu_cond_map_
[
curr_switch_
].
cond_graph_map
[
curr_cond
]
=
g
;
SetSwitchGraph
();
if
(
!
curr_switch_
.
is_null
())
{
// push this {g, args} to all user while graph_inputs for nest while,
// when current condition recall over delete this cond in graph_inputs.
for
(
auto
&
iter
:
graph_inputs_
)
{
iter
.
second
.
push_back
({
g
,
args
});
}
if
(
graph_inputs_
.
find
(
curr_switch_
)
==
graph_inputs_
.
end
())
{
graph_inputs_
[
curr_switch_
].
push_back
({
g
,
args
});
}
}
bool
curr_cond
=
simu_cond_map_
[
curr_switch_
].
curr_cond
;
MS_LOG
(
DEBUG
)
<<
"switch call MsSimuRunGraph:"
<<
curr_cond
<<
", "
<<
g
;
simu_cond_map_
[
curr_switch_
].
cond_graph_map
[
curr_cond
]
=
g
;
SetSwitchGraph
();
}
std
::
vector
<
BaseRef
>
outputs
;
...
...
@@ -205,42 +272,17 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args) {
return
outputs
;
}
void
MsBackend
::
SetSimuCondFlag
(
const
BaseRef
&
c
,
int
flag
)
{
MS_LOG
(
DEBUG
)
<<
"while set cond :"
<<
c
.
ToString
()
<<
", "
<<
simu_cond_map_
.
size
();
if
(
simu_cond_map_
.
find
(
c
)
==
simu_cond_map_
.
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"error c not find"
;
}
simu_cond_map_
[
c
].
flag
=
flag
;
}
int
MsBackend
::
GetSimuCondFlag
(
const
BaseRef
&
c
)
{
BaseRef
cond
=
c
;
if
(
cond
.
is_null
())
{
MS_LOG
(
DEBUG
)
<<
"get curr_switch"
;
cond
=
curr_switch_
;
}
if
(
simu_cond_map_
.
find
(
cond
)
==
simu_cond_map_
.
end
())
{
MS_LOG
(
ERROR
)
<<
"error c not find"
;
return
-
1
;
}
return
simu_cond_map_
[
cond
].
flag
;
}
SwitchCondStatus
MsBackend
::
SetSimuCond
(
const
BaseRef
&
c
,
bool
value
)
{
MS_LOG
(
DEBUG
)
<<
"set cond :"
<<
c
.
ToString
()
<<
", "
<<
simu_cond_map_
.
size
();
CondGraph
cond_graph
;
cond_graph
.
curr_cond
=
value
;
if
(
simu_cond_map_
.
find
(
c
)
==
simu_cond_map_
.
end
())
{
cond_graph
.
flag
=
0
;
simu_cond_map_
[
c
]
=
cond_graph
;
}
if
(
simu_cond_map_
[
c
].
cond_graph_map
.
count
(
value
))
{
if
(
value
==
true
)
{
return
kCondAlreadyRun
;
}
return
kCondAlreadyRun
;
}
simu_cond_map_
[
c
].
curr_cond
=
value
;
MS_LOG
(
DEBUG
)
<<
"end set cond "
;
...
...
mindspore/ccsrc/vm/backend.h
浏览文件 @
6ae8345c
...
...
@@ -16,9 +16,11 @@
#ifndef MINDSPORE_CCSRC_VM_BACKEND_H_
#define MINDSPORE_CCSRC_VM_BACKEND_H_
#include <
string
>
#include <
list
>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include "ir/anf.h"
#include "vm/segment_runner.h"
...
...
@@ -45,6 +47,8 @@ class Backend {
virtual
bool
GetCond
(
const
BaseRef
&
c
,
bool
*
value
);
virtual
void
SetSwitchGraph
()
{}
virtual
void
SetSwitchActive
(
const
BaseRef
&
,
bool
)
{}
virtual
void
RecallGraphInput
(
const
FuncGraphPtr
&
,
const
VectorRef
&
,
const
BaseRef
&
)
{}
virtual
void
SetGraphUserInputs
(
const
FuncGraphPtr
&
,
const
FuncGraphPtr
&
,
const
AnfNodePtrList
&
)
{}
void
set_curr_switch
(
const
BaseRef
&
value
)
{
curr_switch_
=
value
;
...
...
@@ -54,8 +58,6 @@ class Backend {
BaseRef
curr_switch
()
{
return
curr_switch_
;
}
virtual
void
Link
(
GraphId
)
{}
virtual
LinConvertResult
GetMultiGraphRun
(
const
FuncGraphPtr
&
)
{
return
LinConvertResult
();
}
virtual
void
SetSimuCondFlag
(
const
BaseRef
&
,
int
)
{}
virtual
int
GetSimuCondFlag
(
const
BaseRef
&
)
{
return
0
;
}
LinConvertResult
multi_result
()
{
return
multi_result_
;
}
void
set_multi_result
(
const
LinConvertResult
&
value
)
{
multi_result_
=
value
;
}
...
...
@@ -75,11 +77,11 @@ class Backend {
bool
simu_flag_
;
LinConvertResult
multi_result_
;
AnfNodePtr
final_output_
;
std
::
unordered_map
<
FuncGraphPtr
,
std
::
pair
<
FuncGraphPtr
,
AnfNodePtrList
>>
graph_user_inputs_
;
};
struct
CondGraph
{
bool
curr_cond
;
int
flag
;
std
::
unordered_map
<
bool
,
GraphId
>
cond_graph_map
;
};
...
...
@@ -97,15 +99,17 @@ class MsBackend : public Backend {
void
SetSwitchGraph
()
override
;
void
SetSwitchActive
(
const
BaseRef
&
c
,
bool
cond
)
override
;
void
RecallGraphInput
(
const
FuncGraphPtr
&
,
const
VectorRef
&
,
const
BaseRef
&
)
override
;
void
SetGraphUserInputs
(
const
FuncGraphPtr
&
,
const
FuncGraphPtr
&
,
const
AnfNodePtrList
&
)
override
;
void
Link
(
GraphId
)
override
;
AnfNodePtr
ConvertGraphInput
(
const
FuncGraphPtr
&
,
const
AnfNodePtr
&
);
LinConvertResult
GetMultiGraphRun
(
const
FuncGraphPtr
&
g
)
override
;
void
SetSimuCondFlag
(
const
BaseRef
&
c
,
int
flag
)
override
;
int
GetSimuCondFlag
(
const
BaseRef
&
c
)
override
;
private:
session
::
SessionPtr
sess_
;
std
::
unordered_map
<
BaseRef
,
CondGraph
,
BaseRefHash
>
simu_cond_map_
;
std
::
unordered_map
<
GraphId
,
LinConvertResult
>
graph_id_map_
;
std
::
unordered_map
<
BaseRef
,
std
::
list
<
std
::
pair
<
GraphId
,
VectorRef
>>
,
BaseRefHash
>
graph_inputs_
;
};
}
// namespace compile
}
// namespace mindspore
...
...
mindspore/ccsrc/vm/transform.cc
浏览文件 @
6ae8345c
...
...
@@ -390,6 +390,16 @@ void CompileGraph::AddTailCall(const AnfNodePtr &fn, size_t size) {
void
CompileGraph
::
AddPartial
(
const
CNodePtr
&
node
)
{
auto
inputs
=
node
->
inputs
();
VectorRef
args
;
auto
fn
=
inputs
[
1
];
if
(
!
IsValueNode
<
FuncGraph
>
(
fn
))
{
MS_LOG
(
EXCEPTION
)
<<
"The type of 1st input of node must be FuncGraph"
;
}
if
(
backend_
->
is_multi_graph_sink
())
{
auto
func_graph
=
GetValueNode
<
FuncGraphPtr
>
(
fn
);
args
.
emplace_back
(
func_graph
);
AnfNodePtrList
outs
(
inputs
.
begin
()
+
2
,
inputs
.
end
());
backend_
->
SetGraphUserInputs
(
func_graph
,
node
->
func_graph
(),
outs
);
}
for
(
size_t
i
=
1
;
i
<
inputs
.
size
();
i
++
)
{
args
.
emplace_back
(
Ref
(
inputs
[
i
]));
}
...
...
@@ -442,12 +452,17 @@ void CompileGraph::AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim)
}
int
CompileGraph
::
AddCall
(
const
FuncGraphPtr
&
graph
,
const
CNodePtr
&
node
)
{
auto
node_inputs
=
node
->
inputs
();
AnfNodePtr
fn
=
node_inputs
[
0
];
auto
inputs
=
node
->
inputs
();
AnfNodePtr
fn
=
inputs
[
0
];
if
(
backend_
->
is_multi_graph_sink
()
&&
IsValueNode
<
FuncGraph
>
(
fn
))
{
auto
func_graph
=
GetValueNode
<
FuncGraphPtr
>
(
fn
);
AnfNodePtrList
outs
(
inputs
.
begin
()
+
1
,
inputs
.
end
());
backend_
->
SetGraphUserInputs
(
func_graph
,
node
->
func_graph
(),
outs
);
}
(
void
)
Ref
(
fn
);
size_t
size
=
node_
inputs
.
size
();
size_t
size
=
inputs
.
size
();
for
(
size_t
i
=
size
-
1
;
i
>
0
;
i
--
)
{
AddInput
(
node_
inputs
[
i
]);
AddInput
(
inputs
[
i
]);
}
if
(
node
==
graph
->
output
())
{
AddTailCall
(
fn
,
size
);
...
...
mindspore/ccsrc/vm/vm.cc
浏览文件 @
6ae8345c
...
...
@@ -32,7 +32,8 @@ namespace compile {
// Arguments:
// fn_: Callable function.
// args_: Sequence of function args.
StructPartial
::
StructPartial
(
int
fn
,
const
VectorRef
&
args
)
:
fn_
(
fn
),
args_
(
args
)
{}
// fg_: Graph of function.
StructPartial
::
StructPartial
(
int
fn
,
const
VectorRef
&
args
,
const
FuncGraphPtr
&
fg
)
:
fn_
(
fn
),
args_
(
args
),
fg_
(
fg
)
{}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
StructPartial
&
other
)
{
os
<<
"partial("
<<
other
.
fn_
<<
", "
<<
other
.
args_
.
ToString
()
<<
")"
;
...
...
@@ -40,7 +41,7 @@ std::ostream &operator<<(std::ostream &os, const StructPartial &other) {
}
bool
operator
==
(
const
StructPartial
&
lhs
,
const
StructPartial
&
rhs
)
{
return
(
lhs
.
fn_
==
rhs
.
fn_
&&
lhs
.
args_
==
rhs
.
args_
);
return
(
lhs
.
fn_
==
rhs
.
fn_
&&
lhs
.
args_
==
rhs
.
args_
&&
lhs
.
fg_
==
rhs
.
fg_
);
}
StructSimuSwitch
::
StructSimuSwitch
(
const
BaseRef
&
fn
,
const
BaseRef
&
value
)
:
fn_
(
fn
),
value_
(
value
)
{}
...
...
@@ -242,16 +243,6 @@ void FinalVM::InstTailCall(const VectorRef &args) {
int
nargs
=
utils
::
cast
<
int
>
(
args
[
2
]);
auto
new_jmp
=
Ref
(
jmp
);
if
(
backend_
->
simu_flag
())
{
if
(
backend_
->
GetSimuCondFlag
(
BaseRef
())
==
2
)
{
MS_LOG
(
DEBUG
)
<<
"invoke while call tail first"
;
Pop
(
height
);
Push
(
1
);
Popp
();
return
;
}
}
MoveStack
(
nargs
,
height
);
MS_LOG
(
DEBUG
)
<<
"TailCall pushp:"
<<
pc_
<<
", jmp:"
<<
jmp
;
DoJmp
(
new_jmp
);
...
...
@@ -291,8 +282,30 @@ void FinalVM::InstReturn(const VectorRef &args) {
MS_LOG
(
DEBUG
)
<<
"End"
;
}
void
FinalVM
::
InstPartial
(
const
VectorRef
&
args
)
{
MS_LOG
(
DEBUG
)
<<
"Start"
;
void
FinalVM
::
InstSimuPartial
(
const
VectorRef
&
args
)
{
const
size_t
args_size
=
2
;
if
(
args
.
size
()
<
args_size
)
{
MS_LOG
(
ERROR
)
<<
__FUNCTION__
<<
" requires "
<<
args_size
<<
" or more parameters, while the input size is "
<<
args
.
size
()
<<
"."
;
return
;
}
auto
&
node
=
args
[
0
];
if
(
!
utils
::
isa
<
FuncGraphPtr
>
(
node
))
{
MS_LOG
(
ERROR
)
<<
"The type of 1st input of node must be FuncGraph"
;
return
;
}
auto
fg
=
utils
::
cast
<
FuncGraphPtr
>
(
node
);
int
fn_
=
utils
::
cast
<
int
>
(
args
[
1
]);
auto
fn
=
utils
::
cast
<
int
>
(
Ref
(
fn_
));
MS_LOG
(
DEBUG
)
<<
"Partial argssize:"
<<
args
.
size
();
std
::
vector
<
BaseRef
>
outs
(
args
.
size
()
-
2
);
(
void
)
std
::
transform
(
args
.
begin
()
+
2
,
args
.
end
(),
outs
.
begin
(),
[
&
,
this
](
const
BaseRef
&
a
)
{
return
Ref
(
utils
::
cast
<
int
>
(
a
));
});
Push
(
std
::
make_shared
<
StructPartial
>
(
fn
,
VectorRef
(
outs
),
fg
));
}
void
FinalVM
::
InstRealPartial
(
const
VectorRef
&
args
)
{
const
size_t
args_size
=
1
;
if
(
args
.
size
()
<
args_size
)
{
MS_LOG
(
ERROR
)
<<
__FUNCTION__
<<
" requires "
<<
args_size
<<
" or more parameters, while the input size is "
...
...
@@ -304,10 +317,18 @@ void FinalVM::InstPartial(const VectorRef &args) {
auto
fn
=
utils
::
cast
<
int
>
(
Ref
(
fn_
));
MS_LOG
(
DEBUG
)
<<
"Partial argssize:"
<<
args
.
size
();
std
::
vector
<
BaseRef
>
outs
(
args
.
size
()
-
1
);
(
void
)
std
::
transform
(
args
.
begin
()
+
1
,
args
.
end
(),
outs
.
begin
(),
[
&
,
this
](
const
BaseRef
&
a
)
{
return
Ref
(
utils
::
cast
<
int
>
(
a
));
});
Push
(
std
::
make_shared
<
StructPartial
>
(
fn
,
VectorRef
(
outs
)));
}
void
FinalVM
::
InstPartial
(
const
VectorRef
&
args
)
{
MS_LOG
(
DEBUG
)
<<
"Start"
;
if
(
backend_
->
is_multi_graph_sink
())
{
InstSimuPartial
(
args
);
}
else
{
InstRealPartial
(
args
);
}
MS_LOG
(
DEBUG
)
<<
"End"
;
}
...
...
@@ -328,43 +349,57 @@ void FinalVM::InstSimuSwitch(const VectorRef &args) {
bool
bool_value
=
cond
;
SwitchCondStatus
cond_stat
=
backend_
->
SetSimuCond
(
c
,
bool_value
);
int
cond_flag
=
backend_
->
GetSimuCondFlag
(
c
);
MS_LOG
(
DEBUG
)
<<
"Simu switch cond:"
<<
cond
<<
", "
<<
cond_flag
<<
", "
<<
c
.
cast
<
AnfNodePtr
>
()
->
DebugString
();
if
(
cond_flag
==
2
)
{
Popp
();
Popp
();
backend_
->
SetSimuCondFlag
(
c
,
0
);
return
;
}
if
(
cond_stat
==
kCondAlreadyRun
)
{
MS_LOG
(
DEBUG
)
<<
"switch alreay run bool while true jmp"
;
if
(
cond_flag
==
0
)
{
MS_LOG
(
DEBUG
)
<<
"switch second run bool while true jmp"
;
backend_
->
SetSwitchActive
(
c
,
true
);
Push
(
std
::
make_shared
<
StructSimuSwitch
>
(
Ref
(
vtrue
),
c
));
Pushsp
();
backend_
->
SetSimuCondFlag
(
c
,
1
);
return
;
}
else
if
(
cond_flag
==
1
)
{
MS_LOG
(
DEBUG
)
<<
"switch first run bool while if jmp"
;
Push
(
std
::
make_shared
<
StructSimuSwitch
>
(
Ref
(
vfalse
),
c
));
(
void
)
backend_
->
SetSimuCond
(
c
,
false
);
backend_
->
SetSimuCondFlag
(
c
,
2
);
return
;
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"error cond not find"
;
return
;
BaseRef
jmp
=
Ref
(
vtrue
);
if
(
utils
::
isa
<
StructPartial
>
(
jmp
))
{
auto
new_jmp
=
utils
::
cast
<
std
::
shared_ptr
<
StructPartial
>>
(
jmp
);
backend_
->
RecallGraphInput
(
new_jmp
->
fg_
,
new_jmp
->
args_
,
c
);
}
cond_jmp_
[
c
]
=
Ref
(
vfalse
);
Push
(
static_cast
<
int
>
(
cond_stat
));
Popp
();
backend_
->
SetSwitchActive
(
c
,
bool_value
);
return
;
}
if
(
bool_value
)
{
Push
(
std
::
make_shared
<
StructSimuSwitch
>
(
Ref
(
vtrue
),
c
));
Pushsp
();
}
else
{
MergeJmpArgs
(
Ref
(
vfalse
),
c
);
Push
(
std
::
make_shared
<
StructSimuSwitch
>
(
Ref
(
vfalse
),
c
));
}
}
void
FinalVM
::
MergeJmpArgs
(
const
BaseRef
&
jmp
,
const
BaseRef
&
c
)
{
auto
iter
=
cond_jmp_
.
find
(
c
);
if
(
iter
==
cond_jmp_
.
end
())
{
return
;
}
auto
old_jmp
=
utils
::
cast
<
std
::
shared_ptr
<
StructPartial
>>
(
iter
->
second
);
auto
new_jmp
=
utils
::
cast
<
std
::
shared_ptr
<
StructPartial
>>
(
jmp
);
auto
&
old_args
=
old_jmp
->
args_
;
auto
&
new_args
=
new_jmp
->
args_
;
for
(
size_t
i
=
0
;
i
<
new_args
.
size
();
++
i
)
{
auto
&
old_arg
=
old_args
[
i
];
auto
&
new_arg
=
new_args
[
i
];
if
(
utils
::
isa
<
VectorRef
>
(
old_arg
))
{
auto
old_vec_ref
=
utils
::
cast
<
VectorRef
>
(
old_arg
);
if
(
utils
::
isa
<
VectorRef
>
(
new_arg
))
{
auto
new_vec_ref
=
utils
::
cast
<
VectorRef
>
(
new_arg
);
std
::
copy
(
new_vec_ref
.
begin
(),
new_vec_ref
.
end
(),
std
::
back_inserter
(
old_vec_ref
));
}
new_arg
=
old_vec_ref
;
}
else
if
(
utils
::
isa
<
VectorRef
>
(
new_arg
))
{
auto
new_vec_ref
=
utils
::
cast
<
VectorRef
>
(
new_arg
);
new_vec_ref
.
push_back
(
old_arg
);
new_arg
=
new_vec_ref
;
}
else
{
new_arg
=
VectorRef
({
new_arg
,
old_arg
});
}
}
}
void
FinalVM
::
InstRealSwitch
(
const
VectorRef
&
args
)
{
const
size_t
args_size
=
3
;
if
(
args
.
size
()
!=
args_size
)
{
...
...
@@ -399,6 +434,7 @@ void FinalVM::InstSwitch(const VectorRef &args) {
}
else
{
InstRealSwitch
(
args
);
}
MS_LOG
(
DEBUG
)
<<
"End"
;
}
void
FinalVM
::
InstTuple
(
const
VectorRef
&
args
)
{
...
...
mindspore/ccsrc/vm/vm.h
浏览文件 @
6ae8345c
...
...
@@ -27,6 +27,9 @@
#include <utility>
#include <vector>
#include <deque>
#include <unordered_map>
#include "ir/anf.h"
#include "utils/base_ref.h"
namespace
mindspore
{
...
...
@@ -60,13 +63,14 @@ const std::vector<std::string> inst_str{"call", "tail_call", "return", "partial
class
StructPartial
:
public
Base
{
public:
// Initialize StructPartial.
StructPartial
(
int
fn
,
const
VectorRef
&
args
);
StructPartial
(
int
fn
,
const
VectorRef
&
args
,
const
FuncGraphPtr
&
fg
=
nullptr
);
virtual
~
StructPartial
()
=
default
;
MS_DECLARE_PARENT
(
StructPartial
,
Base
)
int
fn_
;
VectorRef
args_
;
FuncGraphPtr
fg_
;
};
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
StructPartial
&
other
);
...
...
@@ -98,6 +102,8 @@ class FinalVM {
void
InstTailCall
(
const
VectorRef
&
args
);
void
InstReturn
(
const
VectorRef
&
args
);
void
InstPartial
(
const
VectorRef
&
args
);
void
InstSimuPartial
(
const
VectorRef
&
args
);
void
InstRealPartial
(
const
VectorRef
&
args
);
void
InstSwitch
(
const
VectorRef
&
args
);
void
InstSimuSwitch
(
const
VectorRef
&
args
);
void
InstRealSwitch
(
const
VectorRef
&
args
);
...
...
@@ -120,6 +126,7 @@ class FinalVM {
void
Pushsp
();
void
Popsp
();
void
DoJmp
(
const
BaseRef
&
jmp
);
void
MergeJmpArgs
(
const
BaseRef
&
jmp
,
const
BaseRef
&
c
);
private:
InstSet
insts_
;
...
...
@@ -128,6 +135,7 @@ class FinalVM {
std
::
stack
<
int
>
retsp_
;
int
pc_
;
int
sp_
;
std
::
unordered_map
<
BaseRef
,
BaseRef
,
BaseRefHash
>
cond_jmp_
;
BackendPtr
backend_
;
const
InstFunctionMap
inst_function_map
=
{
{
Instruction
::
kCall
,
[
this
](
const
VectorRef
&
args
)
{
InstCall
(
args
);
}},
...
...
tests/st/control/test_multigraph_sink.py
0 → 100644
浏览文件 @
6ae8345c
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_multigraph_sink """
import
pytest
import
numpy
as
np
import
mindspore.nn
as
nn
import
mindspore.context
as
context
from
mindspore.common.tensor
import
Tensor
from
mindspore.common
import
dtype
as
mstype
from
mindspore.common
import
ms_function
from
mindspore.ops
import
operations
as
P
def
setup_module
(
module
):
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
save_graphs
=
True
,
device_target
=
"Ascend"
)
context
.
set_context
(
enable_task_sink
=
True
,
device_id
=
0
)
c1
=
Tensor
([
2
],
mstype
.
int32
)
c2
=
Tensor
([
14
],
mstype
.
int32
)
c3
=
Tensor
([
1
],
mstype
.
int32
)
c4
=
Tensor
([
0
],
mstype
.
int32
)
c5
=
Tensor
([
14
],
mstype
.
int32
)
@
ms_function
def
simple_if
(
x
,
y
,
z
):
if
x
<
y
:
x
=
x
+
1
else
:
x
=
x
+
2
x
=
x
+
3
return
x
@
ms_function
def
if_by_if
(
x
,
y
,
z
):
if
x
<
y
:
x
=
x
+
1
if
y
>
x
:
x
=
x
+
2
x
=
x
+
3
return
x
@
ms_function
def
if_in_if
(
x
,
y
,
z
):
out
=
c4
if
x
<
y
:
z
=
c4
+
c4
if
z
<
y
:
z
=
z
+
2
out
=
out
+
z
x
=
x
+
3
out
=
out
+
x
return
out
@
ms_function
def
simple_while
(
x
,
y
,
z
):
y
=
y
+
4
while
x
<
y
:
x
=
x
+
1
x
=
x
+
3
return
x
@
ms_function
def
while_by_while
(
x
,
y
,
z
):
while
x
<
y
:
x
=
x
+
1
while
z
<
c5
:
z
=
z
+
1
x
=
x
+
1
x
=
x
+
1
return
x
@
ms_function
def
while_in_while
(
x
,
y
,
z
):
out
=
c4
while
x
<
y
:
z
=
c4
+
c4
while
z
<
y
:
z
=
z
+
1
out
=
out
+
z
x
=
x
+
1
out
=
out
+
x
return
out
@
ms_function
def
while_by_while_in_while
(
x
,
y
,
z
):
out
=
c4
while
x
<
c2
:
y
=
c4
+
c4
while
y
<
c2
:
y
=
y
+
1
out
=
out
+
y
z
=
c4
+
c4
while
z
<
c2
:
z
=
z
+
1
out
=
out
+
z
x
=
x
+
1
out
=
out
+
x
return
out
@
ms_function
def
while_in_while_in_while
(
x
,
y
,
z
):
out
=
c4
while
x
<
c2
:
y
=
c4
+
c4
while
y
<
c2
:
y
=
y
+
1
z
=
c4
+
c4
while
z
<
c2
:
z
=
z
+
1
out
=
out
+
z
out
=
out
+
y
x
=
x
+
1
out
=
out
+
x
return
out
def
test_simple_if
():
output
=
simple_if
(
c1
,
c2
,
c3
)
expect
=
Tensor
([
6
],
mstype
.
int32
)
assert
output
==
expect
def
test_if_by_if
():
output
=
if_by_if
(
c1
,
c2
,
c3
)
expect
=
Tensor
([
8
],
mstype
.
int32
)
assert
output
==
expect
def
test_if_in_if
():
output
=
if_in_if
(
c1
,
c2
,
c3
)
expect
=
Tensor
([
7
],
mstype
.
int32
)
assert
output
==
expect
def
test_simple_while
():
output
=
simple_while
(
c1
,
c2
,
c3
)
expect
=
Tensor
([
21
],
mstype
.
int32
)
assert
output
==
expect
def
test_while_by_while
():
output
=
while_by_while
(
c1
,
c2
,
c3
)
expect
=
Tensor
([
28
],
mstype
.
int32
)
assert
output
==
expect
def
test_while_in_while
():
output
=
while_in_while
(
c1
,
c2
,
c3
)
expect
=
Tensor
([
1274
],
mstype
.
int32
)
assert
output
==
expect
def
test_while_by_while_in_while
():
output
=
while_by_while_in_while
(
c1
,
c2
,
c3
)
expect
=
Tensor
([
350
],
mstype
.
int32
)
assert
output
==
expect
def
test_while_in_while_in_while
():
output
=
while_in_while_in_while
(
c1
,
c2
,
c3
)
expect
=
Tensor
([
2534
],
mstype
.
int32
)
assert
output
==
expect
tests/ut/python/pynative_mode/test_multigraph_sink.py
0 → 100644
浏览文件 @
6ae8345c
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_multigraph_sink """
import
pytest
import
numpy
as
np
import
mindspore.nn
as
nn
import
mindspore.context
as
context
from
mindspore.common.tensor
import
Tensor
from
mindspore.common
import
dtype
as
mstype
from
mindspore.common
import
ms_function
from
mindspore.ops
import
operations
as
P
def
setup_module
(
module
):
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
save_graphs
=
True
,
device_target
=
"Ascend"
)
context
.
set_context
(
enable_task_sink
=
True
,
device_id
=
0
)
c1
=
Tensor
([
2
],
mstype
.
int32
)
c2
=
Tensor
([
14
],
mstype
.
int32
)
c3
=
Tensor
([
1
],
mstype
.
int32
)
c4
=
Tensor
([
0
],
mstype
.
int32
)
c5
=
Tensor
([
14
],
mstype
.
int32
)
@
ms_function
def
simple_if
(
x
,
y
,
z
):
if
x
<
y
:
x
=
x
+
1
else
:
x
=
x
+
2
x
=
x
+
3
return
x
@
ms_function
def
if_by_if
(
x
,
y
,
z
):
if
x
<
y
:
x
=
x
+
1
if
y
>
x
:
x
=
x
+
2
x
=
x
+
3
return
x
@
ms_function
def
if_in_if
(
x
,
y
,
z
):
out
=
c4
if
x
<
y
:
z
=
c4
+
c4
if
z
<
y
:
z
=
z
+
2
out
=
out
+
z
x
=
x
+
3
out
=
out
+
x
return
out
@
ms_function
def
simple_while
(
x
,
y
,
z
):
y
=
y
+
4
while
x
<
y
:
x
=
x
+
1
x
=
x
+
3
return
x
@
ms_function
def
while_by_while
(
x
,
y
,
z
):
while
x
<
y
:
x
=
x
+
1
while
z
<
c5
:
z
=
z
+
1
x
=
x
+
1
x
=
x
+
1
return
x
def
test_simple_if
():
output
=
simple_if
(
c1
,
c2
,
c3
)
expect
=
Tensor
([
6
],
mstype
.
int32
)
assert
output
==
expect
def
test_if_by_if
():
output
=
if_by_if
(
c1
,
c2
,
c3
)
expect
=
Tensor
([
8
],
mstype
.
int32
)
assert
output
==
expect
def
test_if_in_if
():
output
=
if_in_if
(
c1
,
c2
,
c3
)
expect
=
Tensor
([
7
],
mstype
.
int32
)
assert
output
==
expect
def
test_simple_while
():
output
=
simple_while
(
c1
,
c2
,
c3
)
expect
=
Tensor
([
21
],
mstype
.
int32
)
assert
output
==
expect
def
test_while_by_while
():
output
=
while_by_while
(
c1
,
c2
,
c3
)
expect
=
Tensor
([
28
],
mstype
.
int32
)
assert
output
==
expect
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录