Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
130cc296
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
130cc296
编写于
7月 15, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 15, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2931 Ascend control flow not split graphs
Merge pull request !2931 from zhoufeng/liantiao1
上级
46284661
439d6d61
变更
11
展开全部
隐藏空白更改
内联
并排
Showing
11 changed file
with
762 addition
and
267 deletion
+762
-267
mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
+69
-39
mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
+5
-1
mindspore/ccsrc/backend/session/ascend_control_parser.cc
mindspore/ccsrc/backend/session/ascend_control_parser.cc
+365
-179
mindspore/ccsrc/backend/session/ascend_control_parser.h
mindspore/ccsrc/backend/session/ascend_control_parser.h
+26
-5
mindspore/ccsrc/backend/session/ascend_session.cc
mindspore/ccsrc/backend/session/ascend_session.cc
+246
-7
mindspore/ccsrc/backend/session/ascend_session.h
mindspore/ccsrc/backend/session/ascend_session.h
+9
-0
mindspore/ccsrc/backend/session/kernel_graph.cc
mindspore/ccsrc/backend/session/kernel_graph.cc
+28
-3
mindspore/ccsrc/backend/session/kernel_graph.h
mindspore/ccsrc/backend/session/kernel_graph.h
+7
-0
mindspore/ccsrc/backend/session/session_basic.cc
mindspore/ccsrc/backend/session/session_basic.cc
+5
-33
mindspore/ccsrc/runtime/device/kernel_runtime.cc
mindspore/ccsrc/runtime/device/kernel_runtime.cc
+1
-0
mindspore/ccsrc/utils/utils.h
mindspore/ccsrc/utils/utils.h
+1
-0
未找到文件。
mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
浏览文件 @
130cc296
...
...
@@ -40,6 +40,9 @@ using kernel::KernelBuildInfoPtr;
using
kernel
::
KernelMod
;
using
kernel
::
KernelModPtr
;
namespace
{
constexpr
size_t
kNopNodeInputSize
=
2
;
constexpr
size_t
kNopNodeRealInputIndex
=
1
;
std
::
vector
<
size_t
>
TransShapeToSizet
(
const
abstract
::
ShapePtr
&
shape
)
{
MS_EXCEPTION_IF_NULL
(
shape
);
std
::
vector
<
size_t
>
shape_size_t
;
...
...
@@ -48,6 +51,26 @@ std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) {
}
}
// namespace
AnfNodePtr
AnfRuntimeAlgorithm
::
GetTupleGetItemRealInput
(
const
CNodePtr
&
tuple_get_item
)
{
MS_EXCEPTION_IF_NULL
(
tuple_get_item
);
if
(
tuple_get_item
->
size
()
!=
kTupleGetItemInputSize
)
{
MS_LOG
(
EXCEPTION
)
<<
"The node tuple_get_item must have 2 inputs!"
;
}
return
tuple_get_item
->
input
(
kRealInputNodeIndexInTupleGetItem
);
}
size_t
AnfRuntimeAlgorithm
::
GetTupleGetItemOutIndex
(
const
CNodePtr
&
tuple_get_item
)
{
MS_EXCEPTION_IF_NULL
(
tuple_get_item
);
if
(
tuple_get_item
->
size
()
!=
kTupleGetItemInputSize
)
{
MS_LOG
(
EXCEPTION
)
<<
"The node tuple_get_item must have 2 inputs!"
;
}
auto
output_index_value_node
=
tuple_get_item
->
input
(
kInputNodeOutputIndexInTupleGetItem
);
MS_EXCEPTION_IF_NULL
(
output_index_value_node
);
auto
value_node
=
output_index_value_node
->
cast
<
ValueNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
value_node
);
return
IntToSize
(
GetValue
<
int
>
(
value_node
->
value
()));
}
KernelWithIndex
AnfRuntimeAlgorithm
::
VisitKernel
(
const
AnfNodePtr
&
anf_node
,
size_t
index
)
{
MS_EXCEPTION_IF_NULL
(
anf_node
);
if
(
anf_node
->
isa
<
ValueNode
>
())
{
...
...
@@ -83,49 +106,47 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, siz
}
}
KernelWithIndex
AnfRuntimeAlgorithm
::
VisitKernelWithReturnType
(
const
AnfNodePtr
&
anf_node
,
size_
t
index
,
KernelWithIndex
AnfRuntimeAlgorithm
::
VisitKernelWithReturnType
(
const
AnfNodePtr
&
anf_node
,
in
t
index
,
bool
visit_nop_node
,
const
std
::
vector
<
PrimitivePtr
>
&
return_types
)
{
MS_EXCEPTION_IF_NULL
(
anf_node
);
for
(
const
auto
&
prim_type
:
return_types
)
{
if
(
CheckPrimitiveType
(
anf_node
,
prim_type
))
{
return
std
::
make_pair
(
anf_node
,
index
);
}
if
(
std
::
any_of
(
return_types
.
begin
(),
return_types
.
end
(),
[
&
anf_node
](
const
PrimitivePtr
&
prim_type
)
->
bool
{
return
CheckPrimitiveType
(
anf_node
,
prim_type
);
}))
{
return
KernelWithIndex
(
anf_node
,
index
);
}
if
(
anf_node
->
isa
<
ValueNode
>
())
{
return
std
::
make_pair
(
anf_node
,
0
);
}
else
if
(
anf_node
->
isa
<
Parameter
>
())
{
return
std
::
make_pair
(
anf_node
,
0
);
}
else
if
(
anf_node
->
isa
<
CNode
>
())
{
auto
cnode
=
anf_node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
auto
input0
=
cnode
->
input
(
0
);
MS_EXCEPTION_IF_NULL
(
input0
);
if
(
IsPrimitive
(
input0
,
prim
::
kPrimTupleGetItem
))
{
if
(
cnode
->
inputs
().
size
()
!=
kTupleGetItemInputSize
)
{
MS_LOG
(
EXCEPTION
)
<<
"The node tuple_get_item must have 2 inputs!"
;
}
auto
input2
=
cnode
->
input
(
kInputNodeOutputIndexInTupleGetItem
);
MS_EXCEPTION_IF_NULL
(
input2
);
auto
value_node
=
input2
->
cast
<
ValueNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
value_node
);
int
item_idx
=
GetValue
<
int
>
(
value_node
->
value
());
return
VisitKernelWithReturnType
(
cnode
->
input
(
kRealInputNodeIndexInTupleGetItem
),
IntToSize
(
item_idx
),
visit_nop_node
,
return_types
);
}
else
if
(
IsPrimitive
(
input0
,
prim
::
kPrimDepend
)
||
IsPrimitive
(
input0
,
prim
::
kPrimControlDepend
))
{
return
VisitKernelWithReturnType
(
cnode
->
input
(
kRealInputIndexInDepend
),
0
,
visit_nop_node
,
return_types
);
}
else
if
(
opt
::
IsNopNode
(
cnode
)
&&
visit_nop_node
)
{
if
(
cnode
->
inputs
().
size
()
==
2
)
{
return
VisitKernelWithReturnType
(
cnode
->
input
(
1
),
0
,
visit_nop_node
,
return_types
);
}
else
{
MS_LOG
(
EXCEPTION
)
<<
cnode
->
DebugString
()
<<
"Invalid nop node"
;
if
(
!
anf_node
->
isa
<
CNode
>
())
{
return
KernelWithIndex
(
anf_node
,
0
);
}
auto
cnode
=
anf_node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
if
(
CheckPrimitiveType
(
cnode
,
prim
::
kPrimTupleGetItem
))
{
auto
item_with_index_tmp
=
VisitKernelWithReturnType
(
GetTupleGetItemRealInput
(
cnode
),
GetTupleGetItemOutIndex
(
cnode
),
visit_nop_node
,
return_types
);
if
(
CheckPrimitiveType
(
item_with_index_tmp
.
first
,
prim
::
kPrimMakeTuple
))
{
MS_EXCEPTION_IF_NULL
(
item_with_index_tmp
.
first
);
auto
make_tuple
=
item_with_index_tmp
.
first
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
make_tuple
);
const
std
::
vector
<
AnfNodePtr
>
&
make_tuple_inputs
=
make_tuple
->
inputs
();
size_t
make_tuple_input_index
=
item_with_index_tmp
.
second
+
1
;
if
(
make_tuple_input_index
>=
make_tuple_inputs
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"Index["
<<
make_tuple_input_index
<<
"] out of range["
<<
make_tuple_inputs
.
size
()
<<
"]."
;
}
}
else
{
return
std
::
make_pair
(
anf_node
,
index
);
return
VisitKernelWithReturnType
(
make_tuple_inputs
[
make_tuple_input_index
],
0
,
visit_nop_node
,
return_types
);
}
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"The input is invalid"
;
return
item_with_index_tmp
;
}
if
(
CheckPrimitiveType
(
cnode
,
prim
::
kPrimDepend
)
||
CheckPrimitiveType
(
cnode
,
prim
::
kPrimControlDepend
))
{
return
VisitKernelWithReturnType
(
cnode
->
input
(
kRealInputIndexInDepend
),
index
,
visit_nop_node
,
return_types
);
}
if
(
opt
::
IsNopNode
(
cnode
)
&&
visit_nop_node
)
{
if
(
cnode
->
size
()
!=
kNopNodeInputSize
)
{
MS_LOG
(
EXCEPTION
)
<<
"Invalid nop node "
<<
cnode
->
DebugString
();
}
return
VisitKernelWithReturnType
(
cnode
->
input
(
kNopNodeRealInputIndex
),
0
,
visit_nop_node
,
return_types
);
}
return
KernelWithIndex
(
anf_node
,
index
);
}
std
::
vector
<
AnfNodePtr
>
AnfRuntimeAlgorithm
::
GetAllOutput
(
const
AnfNodePtr
&
node
,
...
...
@@ -591,7 +612,7 @@ const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node,
if
(
opt
::
IsNopNode
(
node
)
&&
visit_nop_node
)
{
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
if
(
cnode
->
inputs
().
size
()
==
2
)
{
if
(
cnode
->
size
()
==
kNopNodeInputSize
)
{
return
AnfRuntimeAlgorithm
::
GetPrevNodeOutputAddr
(
cnode
,
0
);
}
else
{
MS_LOG
(
EXCEPTION
)
<<
node
->
DebugString
()
<<
"Invalid nop node"
;
...
...
@@ -613,7 +634,7 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod
if
(
opt
::
IsNopNode
(
node
)
&&
visit_nop_node
)
{
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
if
(
cnode
->
inputs
().
size
()
==
2
)
{
if
(
cnode
->
inputs
().
size
()
==
kNopNodeInputSize
)
{
return
AnfRuntimeAlgorithm
::
GetPrevNodeMutableOutputAddr
(
cnode
,
0
);
}
else
{
MS_LOG
(
EXCEPTION
)
<<
node
->
DebugString
()
<<
"Invalid nop node."
;
...
...
@@ -806,7 +827,7 @@ bool AnfRuntimeAlgorithm::IsRealKernel(const AnfNodePtr &node) {
IsPrimitive
(
input
,
prim
::
kPrimHistogramSummary
)
||
IsPrimitive
(
input
,
prim
::
kPrimMakeTuple
)
||
IsPrimitive
(
input
,
prim
::
kPrimStateSetItem
)
||
IsPrimitive
(
input
,
prim
::
kPrimDepend
)
||
IsPrimitive
(
input
,
prim
::
kPrimTupleGetItem
)
||
IsPrimitive
(
input
,
prim
::
kPrimControlDepend
)
||
IsPrimitive
(
input
,
prim
::
kPrimReturn
);
IsPrimitive
(
input
,
prim
::
kPrimReturn
)
||
IsPrimitive
(
input
,
prim
::
kPrimPartial
)
;
return
!
is_virtual_node
;
}
...
...
@@ -1117,5 +1138,14 @@ TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputPrecision(const AnfNodePtr &node, s
}
return
GetCNodeOutputPrecision
(
kernel_with_index
.
first
);
}
bool
AnfRuntimeAlgorithm
::
IsCondControlKernel
(
const
CNodePtr
&
node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
if
(
node
->
inputs
().
empty
())
{
MS_LOG
(
EXCEPTION
)
<<
"Illegal null input of cnode."
;
}
auto
input
=
node
->
input
(
kAnfPrimitiveIndex
);
return
IsPrimitive
(
input
,
prim
::
kPrimLabelGoto
)
||
IsPrimitive
(
input
,
prim
::
kPrimLabelSwitch
);
}
}
// namespace session
}
// namespace mindspore
mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
浏览文件 @
130cc296
...
...
@@ -42,9 +42,12 @@ using DeviceAddress = device::DeviceAddress;
using
DeviceAddressPtr
=
device
::
DeviceAddressPtr
;
class
AnfRuntimeAlgorithm
{
public:
// get real input node of tuple_get_item
static
AnfNodePtr
GetTupleGetItemRealInput
(
const
CNodePtr
&
tuple_get_item
);
static
size_t
GetTupleGetItemOutIndex
(
const
CNodePtr
&
tuple_get_item
);
// get input_anf_node's real kernel by recurse
static
KernelWithIndex
VisitKernel
(
const
AnfNodePtr
&
input_anf_node
,
size_t
output_index
);
static
KernelWithIndex
VisitKernelWithReturnType
(
const
AnfNodePtr
&
input_anf_node
,
size_
t
output_index
,
static
KernelWithIndex
VisitKernelWithReturnType
(
const
AnfNodePtr
&
input_anf_node
,
in
t
output_index
,
bool
visit_nop_node
=
false
,
const
std
::
vector
<
PrimitivePtr
>
&
return_types
=
{
prim
::
kPrimMakeTuple
});
...
...
@@ -205,6 +208,7 @@ class AnfRuntimeAlgorithm {
static
TypeId
GetCNodeOutputPrecision
(
const
AnfNodePtr
&
node
);
// get fix output precision from prev node, input_idx is the input index of current node related to prev node.
static
TypeId
GetPrevNodeOutputPrecision
(
const
AnfNodePtr
&
node
,
size_t
input_idx
);
static
bool
IsCondControlKernel
(
const
CNodePtr
&
node
);
};
}
// namespace session
using
AnfAlgo
=
session
::
AnfRuntimeAlgorithm
;
...
...
mindspore/ccsrc/backend/session/ascend_control_parser.cc
浏览文件 @
130cc296
此差异已折叠。
点击以展开。
mindspore/ccsrc/backend/session/ascend_control_parser.h
浏览文件 @
130cc296
...
...
@@ -20,6 +20,8 @@
#include <map>
#include <vector>
#include <tuple>
#include <utility>
#include <functional>
#include "backend/session/kernel_graph.h"
#include "utils/base_ref.h"
#include "utils/contract.h"
...
...
@@ -29,16 +31,23 @@ namespace mindspore {
namespace
session
{
class
AscendControlParser
{
public:
static
void
ChildGraphDataAssign
(
const
std
::
map
<
uint32_t
,
KernelGraphPtr
>
&
graph_id_map
);
static
void
LinkGraph
(
NotNull
<
KernelGraphPtr
>
kg
);
static
void
InsertDependToGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
attch_node
);
static
void
InsertControlDependToGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
first_node
,
NotNull
<
AnfNodePtr
>
second_node
);
static
void
ExecutorValidate
(
NotNull
<
KernelGraphPtr
>
root_graph
);
static
void
UpdateChildGraphOrder
(
NotNull
<
KernelGraphPtr
>
kg
);
static
void
InsertMultipleAssignToGraph
(
NotNull
<
KernelGraphPtr
>
from_graph
,
const
AnfNodePtr
&
jump_node
,
NotNull
<
AnfNodePtr
>
from
,
NotNull
<
AnfNodePtr
>
to
);
private:
class
ReferenceCounter
;
static
void
EraseParameter
(
NotNull
<
KernelGraphPtr
>
root_graph
,
const
std
::
set
<
KernelGraphPtr
>
&
graph_list
);
static
void
EraseLabel
(
NotNull
<
KernelGraphPtr
>
root_graph
);
static
void
ChildGraphDataAssign
(
NotNull
<
KernelGraphPtr
>
kg
,
const
NotNull
<
std
::
vector
<
std
::
pair
<
AnfNodePtr
,
AnfNodePtr
>>
*>
link_list
,
const
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
);
static
NotNull
<
CNodePtr
>
GetStartLabel
(
NotNull
<
KernelGraphPtr
>
kg
,
const
CNodePtr
&
last_node
,
const
CNodePtr
&
last_label
);
static
NotNull
<
CNodePtr
>
ProcessKernelGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
const
CNodePtr
&
last_node
,
...
...
@@ -53,11 +62,10 @@ class AscendControlParser {
static
void
LinkParentGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
const
CNodePtr
&
from_graph_call_node
,
const
CNodePtr
&
last_label
);
static
KernelGraphPtr
ParsePartial
(
NotNull
<
AnfNodePtr
>
node
);
static
void
InsertMultipleAssignToGraph
(
NotNull
<
KernelGraphPtr
>
from_graph
,
NotNull
<
KernelGraphPtr
>
to_graph
,
NotNull
<
AnfNodePtr
>
from
,
NotNull
<
AnfNodePtr
>
to
);
static
AnfNodePtr
InsertAssignToGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
from
,
NotNull
<
AnfNodePtr
>
to
);
static
std
::
vector
<
std
::
pair
<
KernelGraphPtr
,
std
::
vector
<
AnfNodePtr
>>>
ParseCallNode
(
NotNull
<
CNodePtr
>
call_node
);
static
std
::
tuple
<
KernelGraphPtr
,
std
::
vector
<
AnfNodePtr
>>
ParsePartial
(
NotNull
<
AnfNodePtr
>
node
);
// root graph order
static
bool
CheckLabelIndex
(
uint32_t
order_index
,
uint32_t
label_index
,
const
CNodePtr
&
cnode
,
...
...
@@ -65,6 +73,19 @@ class AscendControlParser {
static
std
::
vector
<
CNodePtr
>
RecurseGraph
(
NotNull
<
KernelGraphPtr
>
graph
,
const
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
);
};
class
AscendControlParser
::
ReferenceCounter
{
public:
explicit
ReferenceCounter
(
std
::
function
<
bool
(
int32_t
,
int32_t
)
>
func
)
:
predicate_
(
func
),
count_
()
{}
void
AddReadCount
(
const
AnfNodePtr
&
key
,
int32_t
num
);
void
AddWriteCount
(
const
AnfNodePtr
&
key
,
int32_t
num
);
void
EraseElem
(
const
AnfNodePtr
&
key
);
bool
HasValidElem
()
const
;
std
::
tuple
<
AnfNodePtr
,
int32_t
,
int32_t
>
GetOneValidElem
()
const
;
private:
std
::
function
<
bool
(
int32_t
,
int32_t
)
>
predicate_
;
std
::
map
<
AnfNodePtr
,
std
::
pair
<
int32_t
,
int32_t
>>
count_
;
};
}
// namespace session
}
// namespace mindspore
...
...
mindspore/ccsrc/backend/session/ascend_session.cc
浏览文件 @
130cc296
...
...
@@ -289,6 +289,17 @@ static void RecurseToUpdateCallRealInput(NotNull<KernelGraphPtr> graph,
// this action should from bottom to top
graph
->
UpdateCallRealInput
();
}
void
InsertMakeTupleForOutput
(
NotNull
<
KernelGraphPtr
>
root_graph
)
{
auto
return_node
=
root_graph
->
get_return
();
MS_EXCEPTION_IF_NULL
(
return_node
);
if
(
return_node
->
size
()
<=
kReturnDataIndex
)
{
return
;
}
auto
make_tuple
=
root_graph
->
NewCNode
(
{
NewValueNode
(
std
::
make_shared
<
Primitive
>
(
prim
::
kPrimMakeTuple
->
name
())),
root_graph
->
output
()});
root_graph
->
set_output
(
make_tuple
);
}
}
// namespace
GraphId
AscendSession
::
CompileGraph
(
const
AnfNodePtrList
&
lst
,
const
AnfNodePtrList
&
outputs
)
{
...
...
@@ -305,22 +316,39 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
std
::
vector
<
KernelGraphPtr
>
all_graphs
;
auto
root_graph
=
ConstructKernelGraph
(
func_graph
,
&
all_graphs
);
BackendOptimization
(
all_graphs
);
// split switch
SplitGraphs
(
NOT_NULL
(
root_graph
));
// empty graph dont entry to backend
if
(
root_graph
->
execution_order
().
empty
())
{
MS_LOG
(
INFO
)
<<
root_graph
->
ToString
()
<<
" is empty graph."
;
InsertMakeTupleForOutput
(
NOT_NULL
(
root_graph
));
root_graph
->
set_executable
(
false
);
InitRuntimeResource
();
return
root_graph
->
graph_id
();
}
// create parameter for multiple branch
std
::
set
<
KernelGraphPtr
>
memo
;
CreateMultiBranchOutput
(
NOT_NULL
(
root_graph
),
NOT_NULL
(
&
memo
));
memo
.
clear
();
// insert goto labels and label_sets
LinkChildGraphs
(
NOT_NULL
(
root_graph
));
// resource initialize
InitRuntimeResource
();
// recurse compile child root_graph
std
::
set
<
KernelGraphPtr
>
memo
;
RecurseCompileGraph
(
NOT_NULL
(
root_graph
),
NOT_NULL
(
&
memo
));
IrFusionPass
(
NOT_NULL
(
root_graph
),
NOT_NULL
(
&
memo
));
memo
.
clear
();
SelectKernel
(
NOT_NULL
(
root_graph
));
memo
.
clear
();
HardwareOptimize
(
NOT_NULL
(
root_graph
),
NOT_NULL
(
&
memo
));
memo
.
clear
();
AssignStaticMemory
(
NOT_NULL
(
root_graph
),
NOT_NULL
(
&
memo
));
memo
.
clear
();
UpdateRefOutputMap
(
NOT_NULL
(
root_graph
),
NOT_NULL
(
&
memo
));
memo
.
clear
();
// add make_tuple to the output graph
InsertMakeTupleForOutput
(
NOT_NULL
(
root_graph
));
// root root_graph valiate,include genearte execute order and so on
RootGraphExecutorValidate
(
NOT_NULL
(
root_graph
));
// adjust kernel
...
...
@@ -1682,7 +1710,7 @@ void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<Pri
bool
split_flag
=
false
;
auto
apply_list
=
GetCNodes
(
TopoSort
(
graph
->
get_return
()));
// update the root graph child graph order
AscendControlParser
::
UpdateChildGraphOrder
(
graph
);
graph
->
UpdateChildGraphOrder
(
);
// get child list from current graph
std
::
vector
<
std
::
vector
<
CNodePtr
>>
child_graph_lists
=
GetChildList
(
apply_list
,
cut_prims
);
if
(
child_graph_lists
.
size
()
>
1
)
{
...
...
@@ -1714,7 +1742,7 @@ void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<Pri
}
split_flag
=
true
;
}
AscendControlParser
::
UpdateChildGraphOrder
(
graph
);
graph
->
UpdateChildGraphOrder
(
);
UpdateRealInput
(
graph
,
split_flag
,
memo
);
MS_LOG
(
INFO
)
<<
"Split graph["
<<
graph
->
graph_id
()
<<
"] end"
;
}
...
...
@@ -1753,5 +1781,216 @@ void AscendSession::RecurseCompileGraph(NotNull<KernelGraphPtr> graph, const Not
}
}
}
void
AscendSession
::
CreateMultiBranchOutput
(
NotNull
<
KernelGraphPtr
>
graph
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
)
{
if
(
memo
->
find
(
graph
.
get
())
!=
memo
->
end
())
{
return
;
}
memo
->
insert
(
graph
.
get
());
graph
->
UpdateChildGraphOrder
();
for
(
auto
&
child_graph
:
graph
->
child_graph_order
())
{
CreateMultiBranchOutput
(
NOT_NULL
(
child_graph
),
memo
);
}
std
::
map
<
AnfNodePtr
,
AnfNodePtr
>
need_replace_list
;
auto
node_list
=
GetCNodes
(
TopoSort
(
graph
->
get_return
()));
for
(
auto
&
node
:
node_list
)
{
if
(
AnfAlgo
::
CheckPrimitiveType
(
node
,
prim
::
kPrimCall
))
{
// create a parameter to store the output of multiple branch and set the parameter as the condition graph's output
// auto multi_output_param = graph->NewParameter();
auto
origin_inputs
=
graph
->
inputs
();
auto
output_param
=
CreateNewParameterFromCNode
(
node
,
true
,
graph
.
get
().
get
());
MS_EXCEPTION_IF_NULL
(
graph
->
MutableInputs
());
graph
->
MutableInputs
()
->
operator
=
(
origin_inputs
);
graph
->
AddChildGraphResult
(
output_param
);
std
::
vector
<
AnfNodePtr
>
depend_inputs
=
{
graph
->
NewValueNode
(
NewValueNode
(
std
::
make_shared
<
Primitive
>
(
prim
::
kPrimDepend
->
name
()))),
output_param
,
node
};
auto
depend
=
graph
->
NewCNode
(
depend_inputs
);
need_replace_list
.
emplace
(
node
,
depend
);
MS_LOG
(
INFO
)
<<
"Create parameter "
<<
output_param
->
DebugString
()
<<
" for call node "
<<
node
->
DebugString
()
<<
", depend node is "
<<
depend
->
DebugString
();
// insert assign in order to transfer child graph output to parameter
auto
child_graphs
=
AnfAlgo
::
GetCallNodeKernelGraph
(
node
);
for
(
auto
&
child_graph
:
child_graphs
)
{
MS_EXCEPTION_IF_NULL
(
child_graph
);
if
(
child_graph
->
get_output_null
())
{
continue
;
}
auto
graph_output
=
child_graph
->
output
();
AscendControlParser
::
InsertMultipleAssignToGraph
(
NOT_NULL
(
child_graph
),
nullptr
,
NOT_NULL
(
graph_output
),
NOT_NULL
(
output_param
));
}
}
}
// searching for nodes' input to replace call by depend(parameter, call)
for
(
auto
&
node
:
node_list
)
{
for
(
size_t
i
=
0
;
i
<
node
->
size
();
++
i
)
{
auto
input
=
node
->
input
(
i
);
auto
iter
=
need_replace_list
.
find
(
input
);
if
(
iter
!=
need_replace_list
.
end
())
{
node
->
set_input
(
i
,
iter
->
second
);
}
}
}
}
void
AscendSession
::
IrFusionPass
(
const
NotNull
<
KernelGraphPtr
>
graph
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
)
{
if
(
memo
->
find
(
graph
)
!=
memo
->
end
())
{
return
;
}
memo
->
insert
(
graph
.
get
());
opt
::
AscendBackendIRFusionOptimization
(
graph
);
opt
::
AscendBackendFuseBasicOpt
(
graph
,
true
);
opt
::
AscendBackendGraphKernelOpt
(
graph
,
true
);
graph
->
SetExecOrderByDefault
();
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
bool
save_graphs
=
context_ptr
->
save_graphs_flag
();
auto
save_graphs_path
=
context_ptr
->
save_graphs_path
();
if
(
save_graphs
)
{
if
(
save_graphs_path
.
empty
())
{
save_graphs_path
=
"."
;
}
std
::
string
file_path
=
save_graphs_path
+
"/"
+
"select_kernel_before"
+
"_graph_"
+
std
::
to_string
(
graph
->
graph_id
())
+
".ir"
;
DumpIR
(
file_path
,
graph
.
get
());
}
for
(
auto
&
child_graph
:
graph
->
child_graph_order
())
{
IrFusionPass
(
NOT_NULL
(
child_graph
),
memo
);
}
}
void
AscendSession
::
SelectKernel
(
NotNull
<
KernelGraphPtr
>
root_graph
)
{
MS_LOG
(
INFO
)
<<
"Start select kernel."
;
size_t
raise_precision_count
=
0
;
size_t
reduce_precision_count
=
0
;
std
::
set
<
KernelGraphPtr
>
memo
;
(
void
)
RecurseSelectKernelInfo
(
root_graph
,
NOT_NULL
(
&
memo
),
&
raise_precision_count
,
&
reduce_precision_count
);
memo
.
clear
();
auto
ms_context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
ms_context
);
if
(
ms_context
->
execution_mode
()
==
kGraphMode
)
{
if
(
raise_precision_count
>
0
)
{
MS_LOG
(
WARNING
)
<<
"There has "
<<
raise_precision_count
<<
" node/nodes used raise precision to selected the kernel!"
;
}
if
(
reduce_precision_count
>
0
)
{
MS_LOG
(
WARNING
)
<<
"There has "
<<
raise_precision_count
<<
" node/nodes used reduce precision to selected the kernel!"
;
}
}
MS_LOG
(
INFO
)
<<
"Finish!"
;
}
void
AscendSession
::
RecurseSelectKernelInfo
(
NotNull
<
KernelGraphPtr
>
graph
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
const
memo
,
size_t
*
const
raise_precision_count
,
size_t
*
const
reduce_precision_count
)
const
{
if
(
memo
->
find
(
graph
)
!=
memo
->
end
())
{
return
;
}
memo
->
insert
(
graph
.
get
());
MS_LOG
(
INFO
)
<<
"Start to select kernel info in graph: "
<<
graph
->
graph_id
();
for
(
const
auto
&
cnode
:
graph
->
execution_order
())
{
if
(
AnfAlgo
::
IsCondControlKernel
(
cnode
))
{
std
::
vector
<
KernelGraphPtr
>
child_graphs
;
if
(
AnfAlgo
::
HasNodeAttr
(
kAttrChildGraph
,
cnode
))
{
child_graphs
=
AnfAlgo
::
GetNodeAttr
<
std
::
vector
<
KernelGraphPtr
>>
(
cnode
,
kAttrChildGraph
);
}
for
(
auto
&
child_graph
:
child_graphs
)
{
RecurseSelectKernelInfo
(
NOT_NULL
(
child_graph
),
memo
,
raise_precision_count
,
reduce_precision_count
);
}
}
auto
status
=
device
::
ascend
::
SelectKernelInfo
(
cnode
);
if
(
status
==
device
::
ascend
::
kStatusRaisePrecision
)
{
(
*
raise_precision_count
)
++
;
}
else
if
(
status
==
device
::
ascend
::
kStatusReducePrecision
)
{
(
*
reduce_precision_count
)
++
;
}
MS_LOG
(
INFO
)
<<
"Select ApplyKernel: "
<<
cnode
->
DebugString
();
}
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
bool
save_graphs
=
context_ptr
->
save_graphs_flag
();
auto
save_graphs_path
=
context_ptr
->
save_graphs_path
();
if
(
save_graphs
)
{
if
(
save_graphs_path
.
empty
())
{
save_graphs_path
=
"."
;
}
std
::
string
file_path
=
save_graphs_path
+
"/"
+
"select_kernel_after"
+
"_graph_"
+
std
::
to_string
(
graph
->
graph_id
())
+
".ir"
;
DumpIR
(
file_path
,
graph
.
get
());
}
MS_LOG
(
INFO
)
<<
"Finish selecting kernel info in graph: "
<<
graph
->
graph_id
();
}
void
AscendSession
::
HardwareOptimize
(
NotNull
<
KernelGraphPtr
>
graph
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
const
memo
)
const
{
if
(
memo
->
find
(
graph
)
!=
memo
->
end
())
{
return
;
}
memo
->
insert
(
graph
.
get
());
MS_LOG
(
INFO
)
<<
"Start to do HardwareOptimize in graph: "
<<
graph
->
graph_id
();
// convert kernel Graph to model
predictmodel
::
StepConvertGraph
(
graph
.
get
());
HardwareOptimize
(
graph
.
get
());
for
(
auto
&
child_graph
:
graph
->
child_graph_order
())
{
HardwareOptimize
(
NOT_NULL
(
child_graph
),
memo
);
}
MS_LOG
(
INFO
)
<<
"Finish doing HardwareOptimize in graph: "
<<
graph
->
graph_id
();
}
void
AscendSession
::
AssignStaticMemory
(
NotNull
<
KernelGraphPtr
>
graph
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
const
memo
)
const
{
if
(
memo
->
find
(
graph
)
!=
memo
->
end
())
{
return
;
}
memo
->
insert
(
graph
.
get
());
MS_LOG
(
INFO
)
<<
"Start to assign static memory for parameter in graph: "
<<
graph
->
graph_id
();
// assign static memory for parameters
auto
runtime_instance
=
device
::
KernelRuntimeManager
::
Instance
().
GetKernelRuntime
(
kAscendDevice
,
device_id_
);
MS_EXCEPTION_IF_NULL
(
runtime_instance
);
runtime_instance
->
AssignStaticMemoryInput
(
graph
.
get
().
get
());
runtime_instance
->
AssignStaticMemoryValueNode
(
graph
.
get
().
get
());
for
(
auto
&
child_graph
:
graph
->
child_graph_order
())
{
AssignStaticMemory
(
NOT_NULL
(
child_graph
),
memo
);
}
MS_LOG
(
INFO
)
<<
"Finish assigning static memory for parameter in graph: "
<<
graph
->
graph_id
();
}
void
AscendSession
::
UpdateRefOutputMap
(
NotNull
<
KernelGraphPtr
>
graph
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
const
memo
)
const
{
if
(
memo
->
find
(
graph
)
!=
memo
->
end
())
{
return
;
}
memo
->
insert
(
graph
.
get
());
for
(
auto
&
child_graph
:
graph
->
child_graph_order
())
{
UpdateRefOutputMap
(
NOT_NULL
(
child_graph
),
memo
);
// copy ref map to final graph
auto
child_ref_map
=
child_graph
->
GetRefMap
();
for
(
auto
&
item
:
child_ref_map
)
{
if
(
graph
->
IsInRefOutputMap
(
item
.
first
))
{
MS_LOG
(
WARNING
)
<<
"The ref pair <"
<<
item
.
first
.
first
->
DebugString
()
<<
", "
<<
item
.
first
.
second
<<
"> is already in "
<<
graph
->
ToString
();
continue
;
}
graph
->
AddRefCorrespondPairs
(
item
.
first
,
item
.
second
);
}
}
}
}
// namespace session
}
// namespace mindspore
mindspore/ccsrc/backend/session/ascend_session.h
浏览文件 @
130cc296
...
...
@@ -151,6 +151,15 @@ class AscendSession : public SessionBasic {
// sync intial tensors' data to device
void
SyncInitialTenosrToDevice
();
void
SetFinalGraphSummaryFlag
(
const
std
::
shared_ptr
<
KernelGraph
>
&
kernel_graph
);
// create parameter to receive data from multiple branch output
void
CreateMultiBranchOutput
(
NotNull
<
KernelGraphPtr
>
graph
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
);
void
SelectKernel
(
NotNull
<
KernelGraphPtr
>
root_graph
);
void
RecurseSelectKernelInfo
(
NotNull
<
KernelGraphPtr
>
graph
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
const
memo
,
size_t
*
const
raise_precision_count
,
size_t
*
const
reduce_precision_count
)
const
;
void
IrFusionPass
(
const
NotNull
<
KernelGraphPtr
>
graph
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
);
void
HardwareOptimize
(
const
NotNull
<
KernelGraphPtr
>
graph
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
)
const
;
void
AssignStaticMemory
(
const
NotNull
<
KernelGraphPtr
>
graph
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
)
const
;
void
UpdateRefOutputMap
(
const
NotNull
<
KernelGraphPtr
>
graph
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
)
const
;
// member variables
// key is final_graph_id,value is child graph execute order of final graph
...
...
mindspore/ccsrc/backend/session/kernel_graph.cc
浏览文件 @
130cc296
...
...
@@ -616,8 +616,8 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de
if
(
AnfAlgo
::
HasNodeAttr
(
kControlDependMode
,
cnode
))
{
depend_mode
=
AnfAlgo
::
GetNodeAttr
<
int
>
(
cnode
,
kControlDependMode
);
}
MS_LOG
(
INFO
)
<<
"Prior node["
<<
prior_node
->
DebugString
()
<<
"], depend node["
<<
depend_node
->
DebugString
()
<<
"], depend_mode :"
<<
depend_mode
<<
"."
;
MS_LOG
(
DEBUG
)
<<
"Prior node["
<<
prior_node
->
DebugString
()
<<
"], depend node["
<<
depend_node
->
DebugString
()
<<
"], depend_mode :"
<<
depend_mode
<<
"."
;
if
(
prior_node
->
isa
<
Parameter
>
()
&&
depend_mode
==
1
)
{
prior_nodes
=
GetOutputNodes
(
prior_node
);
}
...
...
@@ -647,7 +647,8 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de
}
MS_EXCEPTION_IF_NULL
(
first_node
);
MS_EXCEPTION_IF_NULL
(
second_node
);
MS_LOG
(
INFO
)
<<
"Add first node:"
<<
first_node
->
DebugString
()
<<
",second node:"
<<
second_node
->
DebugString
();
MS_LOG
(
DEBUG
)
<<
"Add first node:"
<<
first_node
->
DebugString
()
<<
",second node:"
<<
second_node
->
DebugString
();
AddDependEdge
(
second_node
,
first_node
,
1
);
}
}
...
...
@@ -991,6 +992,30 @@ bool KernelGraph::IsFinalOutputKernel(const AnfNodePtr &node) const {
return
false
;
}
void
KernelGraph
::
UpdateChildGraphOrder
()
{
MS_LOG
(
INFO
)
<<
"Update "
<<
ToString
()
<<
" child graph order."
;
SetExecOrderByDefault
();
auto
call_nodes
=
FindNodeByPrimitive
(
std
::
make_shared
<
Primitive
>
(
prim
::
kPrimCall
->
name
()));
std
::
vector
<
KernelGraphPtr
>
child_graph_order
;
for
(
auto
&
call_node
:
call_nodes
)
{
MS_EXCEPTION_IF_NULL
(
call_node
);
auto
call_child_graphs
=
AnfAlgo
::
GetCallNodeKernelGraph
(
call_node
->
cast
<
CNodePtr
>
());
for
(
const
auto
&
child_graph
:
call_child_graphs
)
{
MS_EXCEPTION_IF_NULL
(
child_graph
);
if
(
child_graph
!=
parent_graph_
)
{
auto
shared_this
=
std
::
dynamic_pointer_cast
<
KernelGraph
>
(
shared_from_this
());
MS_EXCEPTION_IF_NULL
(
shared_this
);
child_graph
->
set_parent_graph
(
shared_this
);
}
child_graph_order
.
push_back
(
child_graph
);
}
}
for
(
size_t
i
=
0
;
i
<
child_graph_order
.
size
();
++
i
)
{
MS_LOG
(
INFO
)
<<
"Child graph["
<<
i
<<
"][id:"
<<
child_graph_order
[
i
]
->
graph_id
()
<<
"]"
;
}
child_graph_order_
=
child_graph_order
;
}
std
::
string
KernelGraph
::
ToString
()
const
{
return
std
::
string
(
"kernel_graph_"
).
append
(
std
::
to_string
(
graph_id_
));
}
KernelGraph
::~
KernelGraph
()
{
device
::
KernelRuntimeManager
::
Instance
().
ClearGraphResource
(
graph_id_
);
}
...
...
mindspore/ccsrc/backend/session/kernel_graph.h
浏览文件 @
130cc296
...
...
@@ -156,6 +156,12 @@ class KernelGraph : public FuncGraph {
bool
IsFinalOutputKernel
(
const
AnfNodePtr
&
node
)
const
;
uint32_t
current_epoch
()
const
{
return
current_epoch_
;
}
void
set_current_epoch
(
uint32_t
epoch
)
{
current_epoch_
=
epoch
;
}
void
UpdateChildGraphOrder
();
const
std
::
vector
<
AnfNodePtr
>
&
child_graph_result
()
const
{
return
child_graph_result_
;
}
void
AddChildGraphResult
(
const
AnfNodePtr
&
parameter
)
{
child_graph_result_
.
push_back
(
parameter
);
}
void
set_child_graph_result
(
const
std
::
vector
<
AnfNodePtr
>
&
child_graph_result
)
{
child_graph_result_
=
child_graph_result
;
}
private:
// remove value node form graph
...
...
@@ -173,6 +179,7 @@ class KernelGraph : public FuncGraph {
void
UpdateControlDependRelations
(
const
std
::
vector
<
AnfNodePtr
>
&
depends
);
std
::
shared_ptr
<
std
::
vector
<
AnfNodePtr
>>
inputs_
;
std
::
vector
<
AnfNodePtr
>
child_graph_result_
;
std
::
vector
<
CNodePtr
>
execution_order_
;
uint32_t
graph_id_
;
uint32_t
stream_distinction_label_
;
...
...
mindspore/ccsrc/backend/session/session_basic.cc
浏览文件 @
130cc296
...
...
@@ -74,7 +74,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
return
input_tensors
[
input_idx
];
}
}
MS_LOG
(
EXCEPTION
)
<<
"Parameter : "
<<
node
->
DebugString
()
<<
"has no output addr"
;
MS_LOG
(
EXCEPTION
)
<<
"Parameter : "
<<
node
->
DebugString
()
<<
"
has no output addr"
;
}
}
// if proccess reach here,it remarks item_with_index is a real node(Parameter,or executable CNode)
...
...
@@ -107,8 +107,8 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
return
tensor
;
}
BaseRef
CreatTensorForOutput
(
const
AnfNodePtr
&
anf
,
const
KernelGraph
&
graph
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
)
{
BaseRef
Creat
e
TensorForOutput
(
const
AnfNodePtr
&
anf
,
const
KernelGraph
&
graph
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
)
{
MS_EXCEPTION_IF_NULL
(
anf
);
MS_LOG
(
INFO
)
<<
"Create tensor for output["
<<
anf
->
DebugString
()
<<
"]"
;
auto
item_with_index
=
AnfAlgo
::
VisitKernelWithReturnType
(
anf
,
0
);
...
...
@@ -120,7 +120,7 @@ BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
MS_EXCEPTION_IF_NULL
(
cnode
);
VectorRef
ret
;
for
(
size_t
i
=
1
;
i
<
cnode
->
inputs
().
size
();
++
i
)
{
auto
out
=
CreatTensorForOutput
(
cnode
->
input
(
i
),
graph
,
input_tensors
);
auto
out
=
Creat
e
TensorForOutput
(
cnode
->
input
(
i
),
graph
,
input_tensors
);
ret
.
push_back
(
out
);
}
return
ret
;
...
...
@@ -133,25 +133,6 @@ BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
return
CreateOneTensor
(
item_with_index
.
first
,
item_with_index
.
second
,
graph
,
input_tensors
);
}
// Builds the output value for a node that produces a tuple of tensors.
// Rules, in order:
//   * the node must be a real (executable) kernel — anything else is an error;
//   * a ValueNode is delegated to CreateOneTensor for its single output;
//   * for a CNode that is not a MakeTuple, one tensor is created per output
//     slot and collected into a VectorRef;
//   * otherwise (e.g. a MakeTuple CNode) an empty VectorRef is returned.
// NOTE(review): the MakeTuple case yielding an empty result looks intentional
// here, but callers should confirm they never pass a MakeTuple expecting data.
BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
                            const std::vector<tensor::TensorPtr> &input_tensors) {
  MS_EXCEPTION_IF_NULL(anf);
  // Only nodes that actually execute can own output addresses to read back.
  if (!AnfAlgo::IsRealKernel(anf)) {
    MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] should be a executable kernel";
  }
  // A value node carries exactly one output: reuse the single-tensor path.
  if (anf->isa<ValueNode>()) {
    return CreateOneTensor(anf, 0, graph, input_tensors);
  }
  VectorRef ret;
  if (anf->isa<CNode>() && AnfAlgo::GetCNodeName(anf) != prim::kPrimMakeTuple->name()) {
    // One tensor per output slot of the kernel.
    for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(anf); ++i) {
      auto out = CreateOneTensor(anf, i, graph, input_tensors);
      ret.emplace_back(out);
    }
  }
  return ret;
}
ValueNodePtr
CreateNewValueNode
(
const
AnfNodePtr
&
anf
,
KernelGraph
*
graph
)
{
MS_EXCEPTION_IF_NULL
(
anf
);
MS_EXCEPTION_IF_NULL
(
graph
);
...
...
@@ -880,20 +861,11 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
)
const
{
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
MS_EXCEPTION_IF_NULL
(
outputs
);
if
(
!
kernel_graph
->
child_graph_order
().
empty
())
{
// use the last child graph output as the root graph output
UpdateOutputs
(
kernel_graph
->
child_graph_order
().
back
(),
outputs
,
input_tensors
);
return
;
}
auto
anf_outputs
=
kernel_graph
->
outputs
();
for
(
auto
&
item
:
anf_outputs
)
{
MS_EXCEPTION_IF_NULL
(
item
);
MS_LOG
(
INFO
)
<<
"Update output["
<<
item
->
DebugString
()
<<
"]"
;
if
(
AnfAlgo
::
IsTupleOutput
(
item
)
&&
AnfAlgo
::
IsRealKernel
(
item
))
{
outputs
->
emplace_back
(
CreatTupleForOutput
(
item
,
*
kernel_graph
,
input_tensors
));
continue
;
}
outputs
->
emplace_back
(
CreatTensorForOutput
(
item
,
*
kernel_graph
,
input_tensors
));
outputs
->
emplace_back
(
CreateTensorForOutput
(
item
,
*
kernel_graph
,
input_tensors
));
}
}
...
...
mindspore/ccsrc/runtime/device/kernel_runtime.cc
浏览文件 @
130cc296
...
...
@@ -294,6 +294,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL
(
mem_manager_
);
auto
graph_inputs
=
graph
->
inputs
();
auto
graph_valid_input
=
graph
->
valid_inputs
();
graph_inputs
.
insert
(
graph_inputs
.
end
(),
graph
->
child_graph_result
().
begin
(),
graph
->
child_graph_result
().
end
());
std
::
vector
<
AnfNodePtr
>
need_alloc_nodes
;
for
(
size_t
i
=
0
;
i
<
graph_inputs
.
size
();
++
i
)
{
auto
item
=
graph_inputs
[
i
];
...
...
mindspore/ccsrc/utils/utils.h
浏览文件 @
130cc296
...
...
@@ -240,6 +240,7 @@ constexpr auto kAttrReduceScatterFlag = "reduce_scatter_flag";
constexpr
auto
kAttrOffset
=
"offset"
;
constexpr
auto
kAttrPsKey
=
"ps_key"
;
constexpr
auto
kAttrOptimizerType
=
"optim_type"
;
constexpr
auto
kAttrChildGraph
=
"child_graph"
;
// attr value
constexpr
auto
kValueTargetSwitch
=
"target_switch"
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录