Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
2ee4fdad
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2ee4fdad
编写于
5月 15, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
5月 15, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1165 new control sink entry
Merge pull request !1165 from zhoufeng/new-control-sink-entry
上级
bb374ebc
b78e54a5
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
71 addition
and
6 deletion
+71
-6
mindspore/ccsrc/operator/ops.cc
mindspore/ccsrc/operator/ops.cc
+4
-0
mindspore/ccsrc/operator/ops.h
mindspore/ccsrc/operator/ops.h
+4
-0
mindspore/ccsrc/pipeline/action.cc
mindspore/ccsrc/pipeline/action.cc
+51
-2
mindspore/ccsrc/session/ascend_session.cc
mindspore/ccsrc/session/ascend_session.cc
+1
-1
mindspore/ccsrc/session/ascend_session.h
mindspore/ccsrc/session/ascend_session.h
+1
-1
mindspore/ccsrc/session/session_basic.h
mindspore/ccsrc/session/session_basic.h
+2
-1
mindspore/ccsrc/vm/backend.cc
mindspore/ccsrc/vm/backend.cc
+4
-0
mindspore/ccsrc/vm/backend.h
mindspore/ccsrc/vm/backend.h
+4
-1
未找到文件。
mindspore/ccsrc/operator/ops.cc
浏览文件 @
2ee4fdad
...
...
@@ -78,6 +78,10 @@ const PrimitivePtr kPrimEmbed = std::make_shared<Primitive>("embed");
const
PrimitivePtr
kPrimRefToEmbed
=
std
::
make_shared
<
Primitive
>
(
"RefToEmbed"
);
const
PrimitivePtr
kPrimCreateInstance
=
std
::
make_shared
<
Primitive
>
(
"create_instance"
);
const
PrimitivePtr
kPrimLabelGoto
=
std
::
make_shared
<
Primitive
>
(
"LabelGoto"
);
const
PrimitivePtr
kPrimLabelSwitch
=
std
::
make_shared
<
Primitive
>
(
"LabelSwitch"
);
const
PrimitivePtr
kPrimLabelSet
=
std
::
make_shared
<
Primitive
>
(
"LabelSet"
);
// Structure
const
PrimitivePtr
kPrimStringEqual
=
std
::
make_shared
<
Primitive
>
(
"string_equal"
);
const
PrimitivePtr
kPrimStringConcat
=
std
::
make_shared
<
Primitive
>
(
"string_concat"
);
...
...
mindspore/ccsrc/operator/ops.h
浏览文件 @
2ee4fdad
...
...
@@ -84,6 +84,10 @@ extern const PrimitivePtr kPrimEmbed;
extern
const
PrimitivePtr
kPrimRefToEmbed
;
extern
const
PrimitivePtr
kPrimCreateInstance
;
extern
const
PrimitivePtr
kPrimLabelGoto
;
extern
const
PrimitivePtr
kPrimLabelSwitch
;
extern
const
PrimitivePtr
kPrimLabelSet
;
// Structure
extern
const
PrimitivePtr
kPrimStringEqual
;
extern
const
PrimitivePtr
kPrimStringConcat
;
...
...
mindspore/ccsrc/pipeline/action.cc
浏览文件 @
2ee4fdad
...
...
@@ -269,13 +269,41 @@ bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePa
bool
VmOptimizeAction
(
const
ResourcePtr
&
res
)
{
return
OptimizeAction
(
res
,
kVmPasses
);
}
static
bool
IsCtrlSink
()
{
auto
ms_ctx
=
MsContext
::
GetInstance
();
std
::
string
device_target
=
ms_ctx
->
device_target
();
if
(
device_target
!=
kAscendDevice
)
{
return
false
;
}
if
(
!
ms_ctx
->
enable_task_sink
())
{
return
false
;
}
char
*
enable_ctrl_sink
=
std
::
getenv
(
"ENABLE_CTRL_SINK"
);
if
(
enable_ctrl_sink
==
nullptr
)
{
return
false
;
}
std
::
string
enable_ctrl_sink_str
(
enable_ctrl_sink
);
if
(
enable_ctrl_sink_str
==
"0"
)
{
return
false
;
}
return
true
;
}
bool
TaskEmitAction
(
const
ResourcePtr
&
res
)
{
if
(
res
->
func_graph
()
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"TaskEmit args error"
;
}
FuncGraphPtr
func_graph
=
res
->
func_graph
();
auto
bc_ptr
=
res
->
results
()[
kBackend
].
cast
<
compile
::
BackendPtr
>
();
if
(
IsCtrlSink
())
{
res
->
results
()[
kOutput
]
=
bc_ptr
->
CompileGraph
(
NOT_NULL
(
func_graph
));
return
true
;
}
std
::
vector
<
PrimitivePtr
>
cut_list
=
compile
::
nonlinear_ops
;
if
(
bc_ptr
->
name
()
==
kMsConvert
)
{
cut_list
=
compile
::
GetMsNonlinearOps
();
...
...
@@ -286,10 +314,31 @@ bool TaskEmitAction(const ResourcePtr &res) {
}
bool
ExecuteAction
(
const
ResourcePtr
&
res
)
{
if
(
res
->
results
().
count
(
kOutput
)
==
0
||
!
res
->
results
()[
kOutput
].
is
<
compile
::
FinalVMPtr
>
()
)
{
if
(
res
->
results
().
count
(
kOutput
)
==
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"Execute args error"
;
}
if
(
IsCtrlSink
())
{
if
(
!
res
->
results
()[
kOutput
].
is
<
GraphId
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"Execute args error"
;
}
auto
graph_id
=
res
->
results
()[
kOutput
].
cast
<
GraphId
>
();
auto
bc_ptr
=
res
->
results
()[
kBackend
].
cast
<
std
::
shared_ptr
<
compile
::
MsBackend
>>
();
compile
::
VmEvalFuncPtr
run
=
std
::
make_shared
<
compile
::
VmEvalFunc
>
([
&
bc_ptr
,
graph_id
](
const
VectorRef
&
args
)
->
BaseRef
{
MS_LOG
(
INFO
)
<<
"Execute args size"
<<
args
.
size
();
auto
outs
=
bc_ptr
->
RunGraph
(
graph_id
,
args
);
MS_LOG
(
DEBUG
)
<<
"out size"
<<
outs
.
size
();
return
outs
[
0
];
});
res
->
results
()[
kOutput
]
=
run
;
return
true
;
}
if
(
!
res
->
results
()[
kOutput
].
is
<
compile
::
FinalVMPtr
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"Execute args error"
;
}
compile
::
FinalVMPtr
vm
=
res
->
results
()[
kOutput
].
cast
<
compile
::
FinalVMPtr
>
();
if
(
vm
==
nullptr
)
{
MS_LOG
(
INFO
)
<<
"Call GE to Run the func_graph instead of VM"
;
...
...
mindspore/ccsrc/session/ascend_session.cc
浏览文件 @
2ee4fdad
...
...
@@ -138,7 +138,7 @@ GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrL
return
graph_id
;
}
GraphId
AscendSession
::
CompileGraph
(
const
FuncGraphPtr
&
func_graph
)
{
GraphId
AscendSession
::
CompileGraph
(
NotNull
<
FuncGraphPtr
>
func_graph
)
{
MS_LOG
(
INFO
)
<<
"start"
;
auto
graph
=
ConstructKernelGraph
(
func_graph
);
// split switch
...
...
mindspore/ccsrc/session/ascend_session.h
浏览文件 @
2ee4fdad
...
...
@@ -42,7 +42,7 @@ class AscendSession : public SessionBasic {
context_
=
std
::
make_shared
<
Context
>
(
kAscendDevice
,
device_id
);
}
GraphId
CompileGraph
(
const
AnfNodePtrList
&
lst
,
const
AnfNodePtrList
&
outputs
)
override
;
GraphId
CompileGraph
(
const
FuncGraphPtr
&
func_graph
)
override
;
GraphId
CompileGraph
(
NotNull
<
FuncGraphPtr
>
func_graph
)
override
;
void
RunGraph
(
const
GraphId
&
graph_id
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
inputs
,
VectorRef
*
outputs
)
override
;
void
BuildGraph
(
GraphId
)
override
;
void
BuildOp
(
const
OpRunInfo
&
op_run_info
,
const
GraphInfo
&
graph_info
,
...
...
mindspore/ccsrc/session/session_basic.h
浏览文件 @
2ee4fdad
...
...
@@ -28,6 +28,7 @@
#include "ir/meta_tensor.h"
#include "utils/any.h"
#include "utils/base_ref.h"
#include "utils/contract.h"
#include "pynative/pynative_execute.h"
#include "device/kernel_info.h"
...
...
@@ -57,7 +58,7 @@ class SessionBasic {
virtual
~
SessionBasic
()
{
summary_callback_
=
nullptr
;
}
virtual
GraphId
CompileGraph
(
const
AnfNodePtrList
&
lst
,
const
AnfNodePtrList
&
outputs
)
=
0
;
virtual
GraphId
CompileGraph
(
const
FuncGraphPtr
&
)
{
return
kInvalidGraphId
;
}
virtual
GraphId
CompileGraph
(
NotNull
<
FuncGraphPtr
>
func_graph
)
{
return
kInvalidGraphId
;
}
// build graph, used to handle multiple child graphs
virtual
void
BuildGraph
(
GraphId
)
{}
...
...
mindspore/ccsrc/vm/backend.cc
浏览文件 @
2ee4fdad
...
...
@@ -327,5 +327,9 @@ MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_
sess_
->
RegisterSummaryCallBackFunc
(
callbacks
::
SummarySaveCallback
);
}
GraphId
MsBackend
::
CompileGraph
(
NotNull
<
FuncGraphPtr
>
fg
)
{
return
sess_
->
CompileGraph
(
fg
);
}
VectorRef
MsBackend
::
RunGraph
(
GraphId
graph_id
,
const
VectorRef
&
args
)
{
return
MsRunGraph
(
graph_id
,
args
);
}
}
// namespace compile
}
// namespace mindspore
mindspore/ccsrc/vm/backend.h
浏览文件 @
2ee4fdad
...
...
@@ -22,6 +22,7 @@
#include <unordered_map>
#include <utility>
#include "utils/contract.h"
#include "ir/anf.h"
#include "vm/segment_runner.h"
#include "vm/vm.h"
...
...
@@ -49,7 +50,7 @@ class Backend {
virtual
void
SetSwitchActive
(
const
BaseRef
&
,
bool
)
{}
virtual
void
RecallGraphInput
(
const
FuncGraphPtr
&
,
const
VectorRef
&
,
const
BaseRef
&
)
{}
virtual
void
SetGraphUserInputs
(
const
FuncGraphPtr
&
,
const
FuncGraphPtr
&
,
const
AnfNodePtrList
&
)
{}
virtual
GraphId
CompileGraph
(
NotNull
<
FuncGraphPtr
>
fg
)
{
return
kInvalidGraphId
;
}
void
set_curr_switch
(
const
BaseRef
&
value
)
{
curr_switch_
=
value
;
is_switch_call_
=
true
;
...
...
@@ -104,6 +105,8 @@ class MsBackend : public Backend {
void
Link
(
GraphId
)
override
;
AnfNodePtr
ConvertGraphInput
(
const
FuncGraphPtr
&
,
const
AnfNodePtr
&
);
LinConvertResult
GetMultiGraphRun
(
const
FuncGraphPtr
&
g
)
override
;
GraphId
CompileGraph
(
NotNull
<
FuncGraphPtr
>
fg
)
override
;
VectorRef
RunGraph
(
GraphId
graph_id
,
const
VectorRef
&
args
);
private:
session
::
SessionPtr
sess_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录