Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
ea65e61c
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ea65e61c
编写于
5月 16, 2020
作者:
W
wenchunjiang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
adding new constructKernelGraph to transform all subgraphs to
kernel_graph from root func_graph
上级
92d196f0
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
202 addition
and
1 deletion
+202
-1
mindspore/ccsrc/session/anf_runtime_algorithm.cc
mindspore/ccsrc/session/anf_runtime_algorithm.cc
+10
-0
mindspore/ccsrc/session/anf_runtime_algorithm.h
mindspore/ccsrc/session/anf_runtime_algorithm.h
+1
-0
mindspore/ccsrc/session/session_basic.cc
mindspore/ccsrc/session/session_basic.cc
+183
-1
mindspore/ccsrc/session/session_basic.h
mindspore/ccsrc/session/session_basic.h
+4
-0
mindspore/ccsrc/utils/utils.h
mindspore/ccsrc/utils/utils.h
+4
-0
未找到文件。
mindspore/ccsrc/session/anf_runtime_algorithm.cc
浏览文件 @
ea65e61c
...
...
@@ -909,5 +909,15 @@ bool AnfRuntimeAlgorithm::IsGetNext(const NotNull<AnfNodePtr> &node) {
auto
kernel_name
=
AnfAlgo
::
GetCNodeName
(
node
);
return
kernel_name
==
kGetNextOpName
;
}
FuncGraphPtr
AnfRuntimeAlgorithm
::
GetValueNodeFuncGraph
(
const
AnfNodePtr
&
node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
auto
value_node
=
node
->
cast
<
ValueNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
value_node
);
auto
value
=
value_node
->
value
();
MS_EXCEPTION_IF_NULL
(
value
);
auto
func_graph
=
value
->
cast
<
FuncGraphPtr
>
();
return
func_graph
;
}
}
// namespace session
}
// namespace mindspore
mindspore/ccsrc/session/anf_runtime_algorithm.h
浏览文件 @
ea65e61c
...
...
@@ -181,6 +181,7 @@ class AnfRuntimeAlgorithm {
static
size_t
GetRealInputIndex
(
const
AnfNodePtr
&
anf_node
,
const
size_t
cur_index
);
static
bool
IsCommunicationOp
(
const
AnfNodePtr
&
node
);
static
bool
IsGetNext
(
const
NotNull
<
AnfNodePtr
>
&
node
);
static
FuncGraphPtr
GetValueNodeFuncGraph
(
const
AnfNodePtr
&
node
);
};
}
// namespace session
using
AnfAlgo
=
session
::
AnfRuntimeAlgorithm
;
...
...
mindspore/ccsrc/session/session_basic.cc
100755 → 100644
浏览文件 @
ea65e61c
...
...
@@ -451,6 +451,126 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K
return
new_cnode
;
}
CNodePtr
SessionBasic
::
CreateNewCNode
(
const
CNodePtr
&
cnode
,
KernelGraph
*
graph
)
{
MS_EXCEPTION_IF_NULL
(
cnode
);
MS_EXCEPTION_IF_NULL
(
graph
);
std
::
vector
<
AnfNodePtr
>
cnode_inputs
;
auto
attr_input
=
cnode
->
input
(
kAnfPrimitiveIndex
);
MS_EXCEPTION_IF_NULL
(
attr_input
);
if
(
IsValueNode
<
FuncGraph
>
(
attr_input
))
{
// create primitive of cnode:call
cnode_inputs
=
{
std
::
make_shared
<
ValueNode
>
(
std
::
make_shared
<
Primitive
>
(
kCallOpName
))};
// create a ValueNode<KernelGraph> as input of cnode:call
if
(
graph
->
GetBackendAnfByFrontAnf
(
attr_input
)
!=
nullptr
)
{
cnode_inputs
.
emplace_back
(
graph
->
GetBackendAnfByFrontAnf
(
attr_input
));
}
else
{
auto
new_value_node
=
CreateValueNodeKernelGraph
(
attr_input
,
graph
);
if
(
new_value_node
!=
nullptr
)
{
cnode_inputs
.
emplace_back
(
new_value_node
);
}
}
}
else
if
(
attr_input
->
isa
<
CNode
>
())
{
// create primitive of cnode:call(switch)
cnode_inputs
=
{
std
::
make_shared
<
ValueNode
>
(
std
::
make_shared
<
Primitive
>
(
kCallOpName
))};
if
(
graph
->
GetBackendAnfByFrontAnf
(
attr_input
)
!=
nullptr
)
{
auto
cnode_input
=
graph
->
GetBackendAnfByFrontAnf
(
attr_input
);
auto
prim
=
GetCNodePrimitive
(
cnode_input
);
MS_EXCEPTION_IF_NULL
(
prim
);
if
(
prim
->
name
()
!=
kSwitchOpName
)
{
MS_LOG
(
EXCEPTION
)
<<
"CNode input[0] must be switch."
;
}
cnode_inputs
.
emplace_back
(
cnode_input
);
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"CNode input[0] is CNode:"
<<
attr_input
->
DebugString
()
<<
", but input[0] has not been created."
;
}
}
else
{
// get primitive of old node
auto
prim
=
AnfAlgo
::
GetCNodePrimitive
(
cnode
);
MS_EXCEPTION_IF_NULL
(
prim
);
// push attr to inputs[0] of new cnode
cnode_inputs
=
{
std
::
make_shared
<
ValueNode
>
(
std
::
make_shared
<
Primitive
>
(
*
prim
))};
}
for
(
size_t
input_idx
=
1
;
input_idx
<
cnode
->
inputs
().
size
();
input_idx
++
)
{
auto
anf
=
cnode
->
inputs
()[
input_idx
];
MS_EXCEPTION_IF_NULL
(
anf
);
// anf has been created before
if
(
graph
->
GetBackendAnfByFrontAnf
(
anf
)
!=
nullptr
)
{
cnode_inputs
.
emplace_back
(
graph
->
GetBackendAnfByFrontAnf
(
anf
));
continue
;
}
else
if
(
anf
->
isa
<
ValueNode
>
())
{
if
(
!
IsValueNode
<
FuncGraph
>
(
anf
))
{
// if input is a common value node,
auto
new_value_node
=
CreateNewValueNode
(
anf
,
graph
);
if
(
new_value_node
!=
nullptr
)
{
cnode_inputs
.
emplace_back
(
new_value_node
);
}
}
else
{
// if input is a ValueNode<FuncGraph>
auto
new_value_node
=
CreateValueNodeKernelGraph
(
anf
,
graph
);
if
(
new_value_node
!=
nullptr
)
{
cnode_inputs
.
emplace_back
(
new_value_node
);
}
}
continue
;
}
else
if
(
anf
->
isa
<
Parameter
>
())
{
auto
new_parameter
=
CreateNewParameter
(
anf
,
graph
);
cnode_inputs
.
push_back
(
new_parameter
);
continue
;
}
MS_LOG
(
EXCEPTION
)
<<
"Unexpected input["
<<
anf
->
DebugString
()
<<
"]"
;
}
TraceManager
::
DebugTrace
(
std
::
make_shared
<
TraceCopy
>
(
cnode
->
debug_info
()));
auto
new_cnode
=
graph
->
NewCNode
(
cnode_inputs
);
TraceManager
::
EndTrace
();
return
new_cnode
;
}
ValueNodePtr
SessionBasic
::
CreateValueNodeKernelGraph
(
const
AnfNodePtr
&
anf
,
KernelGraph
*
graph
)
{
MS_EXCEPTION_IF_NULL
(
anf
);
auto
value_node
=
anf
->
cast
<
ValueNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
value_node
);
auto
sub_func_graph
=
AnfAlgo
::
GetValueNodeFuncGraph
(
anf
);
MS_EXCEPTION_IF_NULL
(
sub_func_graph
);
if
(
front_backend_graph_map_
.
find
(
sub_func_graph
)
==
front_backend_graph_map_
.
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"FuncGraph: "
<<
sub_func_graph
->
ToString
()
<<
" has not been transformed to KernelGraph."
;
}
auto
sub_kernel_graph
=
front_backend_graph_map_
[
sub_func_graph
];
ValueNodePtr
new_value_node
=
std
::
make_shared
<
ValueNode
>
(
sub_kernel_graph
);
new_value_node
->
set_abstract
(
value_node
->
abstract
());
// create new kernel_info of new value_node
auto
kernel_info
=
std
::
make_shared
<
device
::
KernelInfo
>
();
kernel_info
->
SetFeatureMapFlag
(
false
);
new_value_node
->
set_kernel_info
(
kernel_info
);
// create kernel_build_info for new value node
auto
kernel_build_info_builder
=
std
::
make_shared
<
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
>
();
AnfAlgo
::
SetSelectKernelBuildInfo
(
kernel_build_info_builder
->
Build
(),
new_value_node
.
get
());
AnfAlgo
::
SetGraphId
(
graph
->
graph_id
(),
new_value_node
.
get
());
graph
->
FrontBackendlMapAdd
(
anf
,
new_value_node
);
graph
->
AddValueNodeToGraph
(
new_value_node
);
return
new_value_node
;
}
ParameterPtr
SessionBasic
::
CreateNewParameter
(
const
AnfNodePtr
&
anf
,
KernelGraph
*
graph
)
{
MS_EXCEPTION_IF_NULL
(
anf
);
if
(
!
anf
->
isa
<
Parameter
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"anf["
<<
anf
->
DebugString
()
<<
"] is not a parameter"
;
}
auto
graph_inputs
=
graph
->
MutableInputs
();
MS_EXCEPTION_IF_NULL
(
graph_inputs
);
auto
new_parameter
=
graph
->
NewParameter
(
anf
->
cast
<
ParameterPtr
>
());
graph_inputs
->
push_back
(
new_parameter
);
graph
->
FrontBackendlMapAdd
(
anf
,
new_parameter
);
return
new_parameter
;
}
KernelGraphPtr
SessionBasic
::
ConstructKernelGraph
(
const
AnfNodePtrList
&
lst
,
const
AnfNodePtrList
&
outputs
)
{
std
::
unordered_map
<
AnfNodePtr
,
AnfNodePtr
>
other_graph_cnode
;
auto
graph
=
NewKernelGraph
();
...
...
@@ -494,7 +614,69 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
return
graph
;
}
std
::
shared_ptr
<
KernelGraph
>
SessionBasic
::
ConstructKernelGraph
(
const
FuncGraphPtr
&
)
{
return
nullptr
;
}
std
::
shared_ptr
<
KernelGraph
>
SessionBasic
::
ConstructKernelGraph
(
const
FuncGraphPtr
&
func_graph
)
{
MS_EXCEPTION_IF_NULL
(
func_graph
);
if
(
front_backend_graph_map_
.
find
(
func_graph
)
!=
front_backend_graph_map_
.
end
())
{
MS_LOG
(
INFO
)
<<
"FuncGraph: "
<<
func_graph
->
ToString
()
<<
" has been transformed to KernelGraph."
;
return
front_backend_graph_map_
[
func_graph
];
}
auto
node_list
=
TopoSort
(
func_graph
->
get_return
());
auto
graph
=
NewKernelGraph
();
front_backend_graph_map_
[
func_graph
]
=
graph
;
MS_LOG
(
INFO
)
<<
"Create graph: "
<<
graph
->
graph_id
();
for
(
const
auto
&
node
:
node_list
)
{
MS_EXCEPTION_IF_NULL
(
node
);
MS_LOG
(
DEBUG
)
<<
"Start create new cnode, node = "
<<
node
->
DebugString
();
if
(
!
node
->
isa
<
CNode
>
())
{
MS_LOG
(
DEBUG
)
<<
"Node "
<<
node
->
DebugString
()
<<
" is not CNode"
;
continue
;
}
else
{
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
// recurse control ops: call, partial
auto
attr_input
=
cnode
->
input
(
kAnfPrimitiveIndex
);
MS_EXCEPTION_IF_NULL
(
attr_input
);
if
(
IsValueNode
<
FuncGraph
>
(
attr_input
))
{
// recurse call subgraph
auto
sub_func_graph
=
AnfAlgo
::
GetValueNodeFuncGraph
(
attr_input
);
ConstructKernelGraph
(
sub_func_graph
);
}
else
if
(
IsValueNode
<
Primitive
>
(
attr_input
))
{
auto
prim
=
GetCNodePrimitive
(
node
);
MS_EXCEPTION_IF_NULL
(
prim
);
if
(
prim
->
name
()
==
kPartialOpName
)
{
// recurse partial subgraph
auto
func_graph_node
=
cnode
->
input
(
kAnfPartialFuncGraphIndex
);
MS_EXCEPTION_IF_NULL
(
func_graph_node
);
auto
sub_func_graph
=
AnfAlgo
::
GetValueNodeFuncGraph
(
func_graph_node
);
ConstructKernelGraph
(
sub_func_graph
);
}
}
// create a new cnode object
auto
new_cnode
=
CreateNewCNode
(
cnode
,
graph
.
get
());
MS_EXCEPTION_IF_NULL
(
new_cnode
);
new_cnode
->
set_abstract
(
cnode
->
abstract
());
new_cnode
->
set_scope
(
cnode
->
scope
());
graph
->
FrontBackendlMapAdd
(
node
,
new_cnode
);
// set original return to kernel_graph
if
(
IsPrimitive
(
new_cnode
->
input
(
kAnfPrimitiveIndex
),
prim
::
kPrimReturn
))
{
graph
->
set_return
(
new_cnode
);
}
}
}
MS_EXCEPTION_IF_NULL
(
context_
);
FuncGraphManagerPtr
manager
=
context_
->
manager
();
if
(
manager
)
{
manager
->
AddFuncGraph
(
graph
);
graph
->
set_manager
(
manager
);
}
graph
->
SetExecOrderByDefault
();
return
graph
;
}
// run graph steps
void
SessionBasic
::
LoadInputData
(
const
std
::
shared_ptr
<
KernelGraph
>
&
kernel_graph
,
...
...
mindspore/ccsrc/session/session_basic.h
浏览文件 @
ea65e61c
...
...
@@ -78,6 +78,7 @@ class SessionBasic {
CNodePtr
CreateNewCNode
(
const
CNodePtr
&
cnode
,
bool
valid_input
,
KernelGraph
*
graph
,
bool
*
from_other_graph
,
std
::
unordered_map
<
AnfNodePtr
,
AnfNodePtr
>
*
other_graph_cnode
);
CNodePtr
CreateNewCNode
(
const
CNodePtr
&
cnode
,
KernelGraph
*
graph
);
// set parameters of final graph
virtual
GraphId
SetFinalGraphInput
(
const
std
::
vector
<
AnfNodePtr
>
&
)
{
return
kInvalidGraphId
;
}
...
...
@@ -111,9 +112,12 @@ class SessionBasic {
// create a new kernel graph and update the graph sum
KernelGraphPtr
NewKernelGraph
();
ParameterPtr
CreateNewParameterFromParameter
(
const
AnfNodePtr
&
anf
,
bool
valid_input
,
KernelGraph
*
graph
);
ValueNodePtr
CreateValueNodeKernelGraph
(
const
AnfNodePtr
&
anf
,
KernelGraph
*
graph
);
ParameterPtr
CreateNewParameter
(
const
AnfNodePtr
&
anf
,
KernelGraph
*
graph
);
std
::
unordered_map
<
GraphId
,
std
::
shared_ptr
<
KernelGraph
>>
graphs_
;
std
::
unordered_map
<
GraphInfo
,
std
::
shared_ptr
<
KernelGraph
>>
run_op_graphs_
;
std
::
unordered_map
<
FuncGraphPtr
,
KernelGraphPtr
>
front_backend_graph_map_
;
std
::
shared_ptr
<
Context
>
context_
;
CallBackFunc
summary_callback_
;
static
GraphId
graph_sum_
;
...
...
mindspore/ccsrc/utils/utils.h
浏览文件 @
ea65e61c
...
...
@@ -141,6 +141,9 @@ constexpr auto kLabelSetOpName = "LabelSet";
constexpr
auto
kLabelSwitchOpName
=
"LabelSwitch"
;
constexpr
auto
kLabelGotoOpName
=
"LabelGoto"
;
constexpr
auto
kBNInferGradOpName
=
"BNInferGrad"
;
constexpr
auto
kCallOpName
=
"call"
;
constexpr
auto
kPartialOpName
=
"partial"
;
constexpr
auto
kSwitchOpName
=
"switch"
;
// attr key name
constexpr
auto
kAttrInputNames
=
"input_names"
;
...
...
@@ -196,6 +199,7 @@ const size_t kMemAlignSize = 512;
// define special index in special node
constexpr
auto
kAnfPrimitiveIndex
=
0
;
constexpr
auto
kAnfPartialFuncGraphIndex
=
1
;
constexpr
auto
kRealInputNodeIndexInTupleGetItem
=
1
;
constexpr
auto
kInputNodeOutputIndexInTupleGetItem
=
2
;
constexpr
auto
kTupleGetItemInputSize
=
3
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录