Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
61464245
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
61464245
编写于
4月 30, 2020
作者:
R
rick_sanchez
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor vm module for multigraph sink
上级
2860fd93
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
121 addition
and
39 deletion
+121
-39
mindspore/ccsrc/session/ascend_session.cc
mindspore/ccsrc/session/ascend_session.cc
+49
-24
mindspore/ccsrc/session/ascend_session.h
mindspore/ccsrc/session/ascend_session.h
+4
-0
mindspore/ccsrc/vm/transform.cc
mindspore/ccsrc/vm/transform.cc
+1
-1
mindspore/ccsrc/vm/vm.cc
mindspore/ccsrc/vm/vm.cc
+56
-14
mindspore/ccsrc/vm/vm.h
mindspore/ccsrc/vm/vm.h
+5
-0
tests/st/control/test_multigraph_sink.py
tests/st/control/test_multigraph_sink.py
+6
-0
未找到文件。
mindspore/ccsrc/session/ascend_session.cc
100755 → 100644
浏览文件 @
61464245
...
@@ -564,42 +564,67 @@ AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodeP
...
@@ -564,42 +564,67 @@ AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodeP
return
create_parameter_from_cnode
(
output_item_with_index
.
first
,
output_item_with_index
.
second
);
return
create_parameter_from_cnode
(
output_item_with_index
.
first
,
output_item_with_index
.
second
);
}
}
void
AscendSession
::
SetFinalGraphOutput
(
const
BaseRef
&
output
)
{
void
AscendSession
::
SetFinalGraphOutput
(
const
AnfNodePtr
&
node
)
{
auto
final_graph
=
GetGraph
(
final_graph_id_
);
MS_EXCEPTION_IF_NULL
(
final_graph
);
if
(
!
utils
::
isa
<
AnfNodePtr
>
(
output
))
{
if
(
!
utils
::
isa
<
ValuePtr
>
(
output
))
{
MS_LOG
(
EXCEPTION
)
<<
"Unknown output type:"
<<
output
.
ToString
();
}
auto
value_ptr
=
utils
::
cast
<
ValuePtr
>
(
output
);
auto
value_node
=
NewValueNode
(
value_ptr
);
MS_EXCEPTION_IF_NULL
(
value_node
);
auto
kernel_info
=
std
::
make_shared
<
device
::
KernelInfo
>
();
value_node
->
set_kernel_info
(
kernel_info
);
value_node
->
set_abstract
(
abstract
::
FromValue
(
value_ptr
));
final_graph
->
set_output
(
final_graph
->
NewCNode
({
NewValueNode
(
prim
::
kPrimMakeTuple
),
value_node
}));
final_graph
->
set_executable
(
false
);
MS_LOG
(
INFO
)
<<
"Not anf output["
<<
output
.
ToString
()
<<
"]"
;
return
;
}
// get the backend anf node related to the output node of front
// get the backend anf node related to the output node of front
auto
output_anf_node
=
utils
::
cast
<
AnfNodePtr
>
(
output
);
auto
output_from_graph_id
=
GetGraphIdByNode
(
node
);
auto
output_from_graph_id
=
GetGraphIdByNode
(
output_anf_node
);
auto
output_from_graph
=
GetGraph
(
output_from_graph_id
);
auto
output_from_graph
=
GetGraph
(
output_from_graph_id
);
MS_EXCEPTION_IF_NULL
(
output_anf_
node
);
MS_EXCEPTION_IF_NULL
(
node
);
MS_LOG
(
INFO
)
<<
"Set the output["
<<
output_anf_
node
->
DebugString
()
<<
"] of graph["
<<
output_from_graph_id
MS_LOG
(
INFO
)
<<
"Set the output["
<<
node
->
DebugString
()
<<
"] of graph["
<<
output_from_graph_id
<<
"] to final graph"
;
<<
"] to final graph"
;
MS_EXCEPTION_IF_NULL
(
output_from_graph
);
MS_EXCEPTION_IF_NULL
(
output_from_graph
);
auto
final_graph
=
GetGraph
(
final_graph_id_
);
MS_EXCEPTION_IF_NULL
(
final_graph
);
// if output is from final graph,it remarks no child graph exist
// if output is from final graph,it remarks no child graph exist
if
(
final_graph_id_
==
output_from_graph_id
)
{
if
(
final_graph_id_
==
output_from_graph_id
)
{
MS_LOG
(
INFO
)
<<
"No child graph,output is "
<<
output_anf_
node
->
DebugString
();
MS_LOG
(
INFO
)
<<
"No child graph,output is "
<<
node
->
DebugString
();
final_graph
->
set_output
(
ConstructOutput
({
output_anf_
node
},
final_graph
));
final_graph
->
set_output
(
ConstructOutput
({
node
},
final_graph
));
final_graph
->
set_executable
(
false
);
final_graph
->
set_executable
(
false
);
return
;
return
;
}
}
final_graph
->
set_output
(
output_from_graph
->
output
());
final_graph
->
set_output
(
output_from_graph
->
output
());
}
}
void
AscendSession
::
SetFinalGraphOutput
(
const
ValuePtr
&
value
)
{
auto
value_node
=
NewValueNode
(
value
);
auto
kernel_info
=
std
::
make_shared
<
device
::
KernelInfo
>
();
value_node
->
set_kernel_info
(
kernel_info
);
value_node
->
set_abstract
(
abstract
::
FromValue
(
value
));
auto
final_graph
=
GetGraph
(
final_graph_id_
);
MS_EXCEPTION_IF_NULL
(
final_graph
);
final_graph
->
set_output
(
final_graph
->
NewCNode
({
NewValueNode
(
prim
::
kPrimMakeTuple
),
value_node
}));
final_graph
->
set_executable
(
false
);
MS_LOG
(
INFO
)
<<
"Not anf output["
<<
value
->
ToString
()
<<
"]"
;
}
void
AscendSession
::
SetFinalGraphOutput
(
const
VectorRef
&
vec_output
)
{
for
(
auto
&
output
:
vec_output
)
{
if
(
utils
::
isa
<
AnfNodePtr
>
(
output
))
{
auto
output_anf_node
=
utils
::
cast
<
AnfNodePtr
>
(
output
);
SetFinalGraphOutput
(
output_anf_node
);
}
else
if
(
utils
::
isa
<
ValuePtr
>
(
output
))
{
auto
value
=
utils
::
cast
<
ValuePtr
>
(
output
);
SetFinalGraphOutput
(
value
);
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"Unknown output type:"
<<
output
.
ToString
();
}
}
}
void
AscendSession
::
SetFinalGraphOutput
(
const
BaseRef
&
output
)
{
if
(
utils
::
isa
<
AnfNodePtr
>
(
output
))
{
auto
output_anf_node
=
utils
::
cast
<
AnfNodePtr
>
(
output
);
SetFinalGraphOutput
(
output_anf_node
);
}
else
if
(
utils
::
isa
<
ValuePtr
>
(
output
))
{
auto
value
=
utils
::
cast
<
ValuePtr
>
(
output
);
SetFinalGraphOutput
(
value
);
}
else
if
(
utils
::
isa
<
VectorRef
>
(
output
))
{
auto
vec_output
=
utils
::
cast
<
VectorRef
>
(
output
);
SetFinalGraphOutput
(
vec_output
);
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"Unknown output type:"
<<
output
.
ToString
();
}
}
KernelGraphPtr
AscendSession
::
GetGraph
(
mindspore
::
GraphId
graph_id
)
{
KernelGraphPtr
AscendSession
::
GetGraph
(
mindspore
::
GraphId
graph_id
)
{
auto
it
=
graphs_
.
find
(
graph_id
);
auto
it
=
graphs_
.
find
(
graph_id
);
if
(
it
==
graphs_
.
end
())
{
if
(
it
==
graphs_
.
end
())
{
...
...
mindspore/ccsrc/session/ascend_session.h
浏览文件 @
61464245
...
@@ -88,6 +88,10 @@ class AscendSession : public SessionBasic {
...
@@ -88,6 +88,10 @@ class AscendSession : public SessionBasic {
size_t
SetChildGraphInput
(
const
KernelGraphPtr
&
graph
,
const
ValuePtr
&
value
,
size_t
input_index
);
size_t
SetChildGraphInput
(
const
KernelGraphPtr
&
graph
,
const
ValuePtr
&
value
,
size_t
input_index
);
size_t
SetChildGraphInput
(
const
KernelGraphPtr
&
graph
,
const
VectorRef
&
vec_args
,
size_t
input_index
);
size_t
SetChildGraphInput
(
const
KernelGraphPtr
&
graph
,
const
VectorRef
&
vec_args
,
size_t
input_index
);
void
SetFinalGraphOutput
(
const
AnfNodePtr
&
node
);
void
SetFinalGraphOutput
(
const
ValuePtr
&
value
);
void
SetFinalGraphOutput
(
const
VectorRef
&
vec_output
);
// merge execution order list of child graphs
// merge execution order list of child graphs
void
MergeGraphExecOrder
();
void
MergeGraphExecOrder
();
// insert assion op to sync data bettween different graphs
// insert assion op to sync data bettween different graphs
...
...
mindspore/ccsrc/vm/transform.cc
浏览文件 @
61464245
...
@@ -243,7 +243,7 @@ void CompileGraph::AddSinkSwitch(const CNodePtr &node) {
...
@@ -243,7 +243,7 @@ void CompileGraph::AddSinkSwitch(const CNodePtr &node) {
AddInst
(
Instruction
::
kCall
,
args
);
AddInst
(
Instruction
::
kCall
,
args
);
args
.
clear
();
args
.
clear
();
args
.
emplace_back
(
true
);
args
.
emplace_back
(
node
->
input
(
1
)
);
AddInst
(
Instruction
::
kSwitchReturn
,
args
);
AddInst
(
Instruction
::
kSwitchReturn
,
args
);
args
.
clear
();
args
.
clear
();
...
...
mindspore/ccsrc/vm/vm.cc
浏览文件 @
61464245
...
@@ -141,17 +141,31 @@ void FinalVM::Popsp() {
...
@@ -141,17 +141,31 @@ void FinalVM::Popsp() {
}
}
}
}
void
FinalVM
::
PushStatus
(
bool
is_switch_call
)
{
ret_status_
.
push
(
is_switch_call
);
}
bool
FinalVM
::
PopStatus
()
{
if
(
ret_status_
.
empty
())
{
return
false
;
}
bool
status
=
ret_status_
.
top
();
ret_status_
.
pop
();
return
status
;
}
void
FinalVM
::
DoJmp
(
const
BaseRef
&
jmp_orig
)
{
void
FinalVM
::
DoJmp
(
const
BaseRef
&
jmp_orig
)
{
MS_LOG
(
DEBUG
)
<<
"Start"
;
MS_LOG
(
DEBUG
)
<<
"Start"
;
BaseRef
jmp
=
jmp_orig
;
BaseRef
jmp
=
jmp_orig
;
if
(
backend_
->
simu_flag
())
{
if
(
backend_
->
simu_flag
())
{
bool
is_switch_call
=
false
;
if
(
utils
::
isa
<
StructSimuSwitch
>
(
jmp
))
{
// need to inherit from Base
if
(
utils
::
isa
<
StructSimuSwitch
>
(
jmp
))
{
// need to inherit from Base
MS_LOG
(
DEBUG
)
<<
"Start jump StructSwitch"
;
MS_LOG
(
DEBUG
)
<<
"Start jump StructSwitch"
;
auto
simu_value
=
utils
::
cast
<
std
::
shared_ptr
<
StructSimuSwitch
>>
(
jmp
);
auto
simu_value
=
utils
::
cast
<
std
::
shared_ptr
<
StructSimuSwitch
>>
(
jmp
);
jmp
=
simu_value
->
fn_
;
jmp
=
simu_value
->
fn_
;
backend_
->
set_curr_switch
(
simu_value
->
value_
);
backend_
->
set_curr_switch
(
simu_value
->
value_
);
is_switch_call
=
true
;
}
}
PushStatus
(
is_switch_call
);
}
}
if
(
utils
::
isa
<
StructPartial
>
(
jmp
))
{
// need to inherit from Base
if
(
utils
::
isa
<
StructPartial
>
(
jmp
))
{
// need to inherit from Base
...
@@ -255,6 +269,13 @@ void FinalVM::InstSwitchReturn(const VectorRef &args) {
...
@@ -255,6 +269,13 @@ void FinalVM::InstSwitchReturn(const VectorRef &args) {
MS_LOG
(
ERROR
)
<<
__FUNCTION__
<<
" requires one parameter, while the input size is "
<<
args
.
size
()
<<
"."
;
MS_LOG
(
ERROR
)
<<
__FUNCTION__
<<
" requires one parameter, while the input size is "
<<
args
.
size
()
<<
"."
;
return
;
return
;
}
}
auto
rv
=
Ref
(
-
1
);
if
(
utils
::
isa
<
AnfNodePtr
>
(
rv
)
||
utils
::
isa
<
VectorRef
>
(
rv
))
{
auto
&
c
=
args
[
0
];
cond_out_
[
c
]
=
rv
;
}
Pop
(
1
);
Pop
(
1
);
Popsp
();
Popsp
();
}
}
...
@@ -272,8 +293,20 @@ void FinalVM::InstReturn(const VectorRef &args) {
...
@@ -272,8 +293,20 @@ void FinalVM::InstReturn(const VectorRef &args) {
int
height
=
utils
::
cast
<
int
>
(
args
[
1
]);
int
height
=
utils
::
cast
<
int
>
(
args
[
1
]);
auto
rv
=
Ref
(
rpos
);
auto
rv
=
Ref
(
rpos
);
if
(
backend_
->
simu_flag
()
&&
backend_
->
is_switch_call
())
{
if
(
backend_
->
simu_flag
())
{
backend_
->
SetSwitchGraph
();
auto
c
=
backend_
->
curr_switch
();
auto
status
=
PopStatus
();
if
(
status
)
{
auto
iter
=
cond_out_
.
find
(
c
);
if
(
iter
!=
cond_out_
.
end
())
{
rv
=
MergeArgs
(
rv
,
iter
->
second
);
cond_out_
.
erase
(
iter
);
}
}
if
(
backend_
->
is_switch_call
())
{
backend_
->
SetSwitchGraph
();
}
}
}
Pop
(
height
);
Pop
(
height
);
...
@@ -383,21 +416,30 @@ void FinalVM::MergeJmpArgs(const BaseRef &jmp, const BaseRef &c) {
...
@@ -383,21 +416,30 @@ void FinalVM::MergeJmpArgs(const BaseRef &jmp, const BaseRef &c) {
for
(
size_t
i
=
0
;
i
<
new_args
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
new_args
.
size
();
++
i
)
{
auto
&
old_arg
=
old_args
[
i
];
auto
&
old_arg
=
old_args
[
i
];
auto
&
new_arg
=
new_args
[
i
];
auto
&
new_arg
=
new_args
[
i
];
if
(
utils
::
isa
<
VectorRef
>
(
old_arg
))
{
new_arg
=
MergeArgs
(
old_arg
,
new_arg
);
auto
old_vec_ref
=
utils
::
cast
<
VectorRef
>
(
old_arg
);
}
if
(
utils
::
isa
<
VectorRef
>
(
new_arg
))
{
}
auto
new_vec_ref
=
utils
::
cast
<
VectorRef
>
(
new_arg
);
std
::
copy
(
new_vec_ref
.
begin
(),
new_vec_ref
.
end
(),
std
::
back_inserter
(
old_vec_ref
));
BaseRef
FinalVM
::
MergeArgs
(
const
BaseRef
&
first
,
const
BaseRef
&
second
)
{
}
MS_LOG
(
DEBUG
)
<<
__FUNCTION__
<<
": "
<<
first
.
ToString
()
<<
", "
<<
second
.
ToString
();
new_arg
=
old_vec_ref
;
if
(
utils
::
isa
<
VectorRef
>
(
first
))
{
}
else
if
(
utils
::
isa
<
VectorRef
>
(
new_arg
))
{
auto
old_vec_ref
=
utils
::
cast
<
VectorRef
>
(
first
);
auto
new_vec_ref
=
utils
::
cast
<
VectorRef
>
(
new_arg
);
if
(
utils
::
isa
<
VectorRef
>
(
second
))
{
new_vec_ref
.
push_back
(
old_arg
);
auto
new_vec_ref
=
utils
::
cast
<
VectorRef
>
(
second
);
new_arg
=
new_vec_ref
;
std
::
copy
(
new_vec_ref
.
begin
(),
new_vec_ref
.
end
(),
std
::
back_inserter
(
old_vec_ref
))
;
}
else
{
}
else
{
new_arg
=
VectorRef
({
new_arg
,
old_arg
}
);
old_vec_ref
.
push_back
(
second
);
}
}
return
old_vec_ref
;
}
if
(
utils
::
isa
<
VectorRef
>
(
second
))
{
auto
new_vec_ref
=
utils
::
cast
<
VectorRef
>
(
second
);
new_vec_ref
.
push_back
(
first
);
return
new_vec_ref
;
}
}
return
VectorRef
({
first
,
second
});
}
}
void
FinalVM
::
InstRealSwitch
(
const
VectorRef
&
args
)
{
void
FinalVM
::
InstRealSwitch
(
const
VectorRef
&
args
)
{
...
...
mindspore/ccsrc/vm/vm.h
浏览文件 @
61464245
...
@@ -125,17 +125,22 @@ class FinalVM {
...
@@ -125,17 +125,22 @@ class FinalVM {
void
Popp
();
void
Popp
();
void
Pushsp
();
void
Pushsp
();
void
Popsp
();
void
Popsp
();
void
PushStatus
(
bool
is_switch_call
);
bool
PopStatus
();
void
DoJmp
(
const
BaseRef
&
jmp
);
void
DoJmp
(
const
BaseRef
&
jmp
);
void
MergeJmpArgs
(
const
BaseRef
&
jmp
,
const
BaseRef
&
c
);
void
MergeJmpArgs
(
const
BaseRef
&
jmp
,
const
BaseRef
&
c
);
BaseRef
MergeArgs
(
const
BaseRef
&
first
,
const
BaseRef
&
second
);
private:
private:
InstSet
insts_
;
InstSet
insts_
;
std
::
deque
<
BaseRef
>
insts_stack_
;
std
::
deque
<
BaseRef
>
insts_stack_
;
std
::
stack
<
int
>
retp_
;
std
::
stack
<
int
>
retp_
;
std
::
stack
<
int
>
retsp_
;
std
::
stack
<
int
>
retsp_
;
std
::
stack
<
bool
>
ret_status_
;
int
pc_
;
int
pc_
;
int
sp_
;
int
sp_
;
std
::
unordered_map
<
BaseRef
,
BaseRef
,
BaseRefHash
>
cond_jmp_
;
std
::
unordered_map
<
BaseRef
,
BaseRef
,
BaseRefHash
>
cond_jmp_
;
std
::
unordered_map
<
BaseRef
,
BaseRef
,
BaseRefHash
>
cond_out_
;
BackendPtr
backend_
;
BackendPtr
backend_
;
const
InstFunctionMap
inst_function_map
=
{
const
InstFunctionMap
inst_function_map
=
{
{
Instruction
::
kCall
,
[
this
](
const
VectorRef
&
args
)
{
InstCall
(
args
);
}},
{
Instruction
::
kCall
,
[
this
](
const
VectorRef
&
args
)
{
InstCall
(
args
);
}},
...
...
tests/st/control/test_multigraph_sink.py
浏览文件 @
61464245
...
@@ -26,6 +26,7 @@ from mindspore.ops import operations as P
...
@@ -26,6 +26,7 @@ from mindspore.ops import operations as P
def
setup_module
(
module
):
def
setup_module
(
module
):
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
device_target
=
"Ascend"
)
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
device_target
=
"Ascend"
)
c1
=
Tensor
([
2
],
mstype
.
int32
)
c1
=
Tensor
([
2
],
mstype
.
int32
)
c2
=
Tensor
([
14
],
mstype
.
int32
)
c2
=
Tensor
([
14
],
mstype
.
int32
)
c3
=
Tensor
([
1
],
mstype
.
int32
)
c3
=
Tensor
([
1
],
mstype
.
int32
)
...
@@ -149,6 +150,10 @@ def test_if_by_if():
...
@@ -149,6 +150,10 @@ def test_if_by_if():
assert
output
==
expect
assert
output
==
expect
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
platform_arm_ascend_training
@
pytest
.
mark
.
env_onecard
def
test_if_in_if
():
def
test_if_in_if
():
output
=
if_in_if
(
c1
,
c2
,
c3
)
output
=
if_in_if
(
c1
,
c2
,
c3
)
expect
=
Tensor
([
7
],
mstype
.
int32
)
expect
=
Tensor
([
7
],
mstype
.
int32
)
...
@@ -194,6 +199,7 @@ def test_while_by_while_in_while():
...
@@ -194,6 +199,7 @@ def test_while_by_while_in_while():
expect
=
Tensor
([
350
],
mstype
.
int32
)
expect
=
Tensor
([
350
],
mstype
.
int32
)
assert
output
==
expect
assert
output
==
expect
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
platform_arm_ascend_training
@
pytest
.
mark
.
platform_arm_ascend_training
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录