Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
485ac838
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
485ac838
编写于
7月 20, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 20, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3162 split tuple output node to maketuple
Merge pull request !3162 from lianliguang/split-tuple-node-to-make-tuple
上级
f19aeaa0
d10d1a17
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
195 addition
and
170 deletion
+195
-170
mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc
...c/backend/optimizer/common/common_backend_optimization.cc
+1
-1
mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.cc
...nd/optimizer/pass/convert_tuple_input_to_dynamic_input.cc
+17
-71
mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc
...ckend/optimizer/pass/convert_tuple_output_to_maketuple.cc
+46
-8
mindspore/ccsrc/backend/session/kernel_graph.cc
mindspore/ccsrc/backend/session/kernel_graph.cc
+123
-88
mindspore/ccsrc/backend/session/kernel_graph.h
mindspore/ccsrc/backend/session/kernel_graph.h
+7
-1
tests/ut/cpp/session/kernel_graph_test.cc
tests/ut/cpp/session/kernel_graph_test.cc
+1
-1
未找到文件。
mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc
浏览文件 @
485ac838
...
...
@@ -47,8 +47,8 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern
common_pm
->
AddPass
(
std
::
make_shared
<
ConvertConstInputToAttr
>
());
common_pm
->
AddPass
(
std
::
make_shared
<
ConstToAttrStridedSliceGradPass
>
());
common_pm
->
AddPass
(
std
::
make_shared
<
ConvertConstInputToTensorInput
>
());
common_pm
->
AddPass
(
std
::
make_shared
<
ConvertTupleInputToDynamicInput
>
());
common_pm
->
AddPass
(
std
::
make_shared
<
ConvertTupleOutputToMaketuple
>
());
common_pm
->
AddPass
(
std
::
make_shared
<
ConvertTupleInputToDynamicInput
>
());
optimizer
->
AddPassManager
(
common_pm
);
(
void
)
optimizer
->
Optimize
(
kernel_graph
);
kernel_graph
->
SetExecOrderByDefault
();
...
...
mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.cc
浏览文件 @
485ac838
...
...
@@ -27,86 +27,33 @@
namespace
mindspore
{
namespace
opt
{
namespace
{
bool
MakeValueNode
(
const
AnfNodePtr
&
node
)
{
auto
value_node
=
node
->
cast
<
ValueNodePtr
>
();
if
(
value_node
==
nullptr
)
{
return
false
;
}
// create kernel_info fo new value node
auto
kernel_info
=
std
::
make_shared
<
device
::
KernelInfo
>
();
value_node
->
set_kernel_info
(
kernel_info
);
// create kernel_build_info for new value node
auto
kernel_build_info_builder
=
std
::
make_shared
<
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
>
();
// set the format of value_node to DEFAULT_FORMAT
kernel_build_info_builder
->
SetOutputsFormat
(
std
::
vector
<
std
::
string
>
{
kOpFormat_DEFAULT
});
// set value node initial device data type = infer data type
TypeId
infer_data_type
;
if
(
AnfAlgo
::
GetOutputTensorNum
(
value_node
)
==
0
)
{
infer_data_type
=
kTypeUnknown
;
}
else
{
infer_data_type
=
AnfAlgo
::
GetOutputInferDataType
(
value_node
,
0
);
}
kernel_build_info_builder
->
SetOutputsDeviceType
(
std
::
vector
<
TypeId
>
{
infer_data_type
});
AnfAlgo
::
SetSelectKernelBuildInfo
(
kernel_build_info_builder
->
Build
(),
value_node
.
get
());
return
true
;
}
void
ConvertTupleOuputToPlantInputs
(
const
FuncGraphPtr
&
graph
,
const
AnfNodePtr
&
input_node
,
std
::
vector
<
AnfNodePtr
>
*
plant_inputs
,
std
::
vector
<
int
>
*
dyn_input_sizes
)
{
MS_EXCEPTION_IF_NULL
(
plant_inputs
);
MS_EXCEPTION_IF_NULL
(
dyn_input_sizes
);
MS_EXCEPTION_IF_NULL
(
graph
);
auto
output_size
=
AnfAlgo
::
GetOutputTensorNum
(
input_node
);
dyn_input_sizes
->
push_back
(
output_size
);
std
::
vector
<
AnfNodePtr
>
convert_inputs
;
auto
kernel_graph
=
graph
->
cast
<
KernelGraphPtr
>
();
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
if
(
input_node
->
isa
<
ValueNode
>
())
{
auto
value_node
=
input_node
->
cast
<
ValueNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
value_node
);
convert_inputs
=
kernel_graph
->
SplitTupleValueNodeToNodeList
(
value_node
);
}
else
{
for
(
size_t
index
=
0
;
index
<
output_size
;
++
index
)
{
auto
tuple_get_item
=
CreatTupleGetItemNode
(
graph
,
input_node
,
index
);
AnfAlgo
::
SetOutputInferTypeAndShape
({
AnfAlgo
::
GetOutputInferDataType
(
input_node
,
index
)},
{
AnfAlgo
::
GetOutputInferShape
(
input_node
,
index
)},
tuple_get_item
.
get
());
convert_inputs
.
emplace_back
(
tuple_get_item
);
}
}
(
void
)
std
::
copy
(
convert_inputs
.
begin
(),
convert_inputs
.
end
(),
std
::
back_inserter
(
*
plant_inputs
));
}
void
ConvertMakeTupleInputToPlantInputs
(
const
FuncGraphPtr
&
graph
,
const
CNodePtr
&
cnode_ptr
)
{
MS_EXCEPTION_IF_NULL
(
cnode_ptr
);
MS_EXCEPTION_IF_NULL
(
graph
);
auto
&
ori_args
=
cnode_ptr
->
inputs
();
if
(
ori_args
.
size
()
<
1
)
{
return
;
}
std
::
vector
<
AnfNodePtr
>
plant_inputs
;
std
::
vector
<
int
>
dyn_input_sizes
;
plant_inputs
.
push_back
(
ori_args
[
kAnfPrimitiveIndex
]);
for
(
size_t
i
=
1
;
i
<
ori_args
.
size
();
++
i
)
{
auto
input_node
=
ori_args
[
i
];
if
(
IsPrimitiveCNode
(
input_node
,
prim
::
kPrimMakeTuple
))
{
plant_inputs
.
push_back
(
AnfAlgo
::
GetCNodePrimitiveNode
(
cnode_ptr
));
for
(
size_t
i
=
0
;
i
<
AnfAlgo
::
GetInputTensorNum
(
cnode_ptr
);
++
i
)
{
auto
input_node
=
AnfAlgo
::
GetInputNode
(
cnode_ptr
,
i
);
MS_EXCEPTION_IF_NULL
(
input_node
);
if
(
input_node
->
isa
<
CNode
>
()
&&
AnfAlgo
::
CheckPrimitiveType
(
input_node
,
prim
::
kPrimMakeTuple
))
{
auto
input_size
=
AnfAlgo
::
GetOutputTensorNum
(
input_node
);
dyn_input_sizes
.
push_back
(
input_size
);
auto
cnode
=
input_node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
auto
inputs
=
cnode
->
inputs
();
for
(
size_t
j
=
1
;
j
<
inputs
.
size
();
++
j
)
{
MS_EXCEPTION_IF_NULL
(
inputs
[
j
]);
if
(
IsValueNode
<
tensor
::
Tensor
>
(
inputs
[
j
]))
{
auto
success
=
MakeValueNode
(
inputs
[
j
]);
auto
make_tuple
=
input_node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
make_tuple
);
for
(
size_t
j
=
0
;
j
<
AnfAlgo
::
GetInputTensorNum
(
make_tuple
);
++
j
)
{
auto
dyn_input_node
=
AnfAlgo
::
GetInputNode
(
make_tuple
,
j
);
MS_EXCEPTION_IF_NULL
(
dyn_input_node
);
if
(
IsValueNode
<
tensor
::
Tensor
>
(
dyn_input_node
))
{
auto
kernel_graph
=
graph
->
cast
<
KernelGraphPtr
>
();
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
auto
success
=
kernel_graph
->
NewValueNode
(
dyn_input_node
->
cast
<
ValueNodePtr
>
());
if
(
!
success
)
{
MS_LOG
(
WARNING
)
<<
"Make value node failed, "
<<
inputs
[
j
]
->
DebugString
();
MS_LOG
(
WARNING
)
<<
"Make value node failed, "
<<
dyn_input_node
->
DebugString
();
}
}
plant_inputs
.
push_back
(
inputs
[
j
]
);
plant_inputs
.
push_back
(
dyn_input_node
);
}
}
else
if
(
input_node
->
Type
()
!=
nullptr
&&
AnfAlgo
::
IsTupleOutput
(
input_node
))
{
ConvertTupleOuputToPlantInputs
(
graph
,
input_node
,
&
plant_inputs
,
&
dyn_input_sizes
);
}
else
{
dyn_input_sizes
.
push_back
(
-
1
);
plant_inputs
.
push_back
(
input_node
);
...
...
@@ -139,9 +86,8 @@ const AnfNodePtr ConvertTupleInputToDynamicInput::Process(const FuncGraphPtr &fu
for
(
auto
&
t
:
todos
)
{
ConvertMakeTupleInputToPlantInputs
(
sub_graph
,
t
->
cast
<
CNodePtr
>
());
}
}
else
{
ConvertMakeTupleInputToPlantInputs
(
func_graph
,
node
->
cast
<
CNodePtr
>
());
}
ConvertMakeTupleInputToPlantInputs
(
func_graph
,
node
->
cast
<
CNodePtr
>
());
return
node
;
}
}
// namespace opt
...
...
mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc
浏览文件 @
485ac838
...
...
@@ -25,6 +25,38 @@
namespace
mindspore
{
namespace
opt
{
namespace
{
CNodePtr
ConvertTupleOuputToPlantInputs
(
const
FuncGraphPtr
&
graph
,
const
AnfNodePtr
&
input_node
)
{
MS_EXCEPTION_IF_NULL
(
graph
);
if
(
!
AnfAlgo
::
IsTupleOutput
(
input_node
))
{
MS_LOG
(
EXCEPTION
)
<<
"Cannot using the function to convert a not tuple output node to maketuple!"
;
}
if
(
input_node
->
isa
<
CNode
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"The function can only split a parameter or valuenode bug got "
<<
input_node
->
DebugString
();
}
std
::
vector
<
AnfNodePtr
>
convert_inputs
=
{
NewValueNode
(
prim
::
kPrimMakeTuple
)};
auto
kernel_graph
=
graph
->
cast
<
KernelGraphPtr
>
();
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
auto
splited_node_list
=
kernel_graph
->
SplitTupleOutputNodeToNodeList
(
input_node
);
for
(
const
auto
&
node
:
splited_node_list
)
{
if
(
AnfAlgo
::
IsTupleOutput
(
node
))
{
convert_inputs
.
emplace_back
(
ConvertTupleOuputToPlantInputs
(
graph
,
node
));
continue
;
}
convert_inputs
.
emplace_back
(
node
);
}
auto
make_tuple
=
graph
->
NewCNode
(
convert_inputs
);
std
::
vector
<
abstract
::
AbstractBasePtr
>
abstract_list
;
auto
make_tuple_input_size
=
AnfAlgo
::
GetInputTensorNum
(
make_tuple
);
for
(
size_t
index
=
0
;
index
<
make_tuple_input_size
;
++
index
)
{
auto
make_tuple_input
=
AnfAlgo
::
GetInputNode
(
make_tuple
,
index
);
MS_EXCEPTION_IF_NULL
(
make_tuple_input
);
abstract_list
.
emplace_back
(
make_tuple_input
->
abstract
());
}
make_tuple
->
set_abstract
(
std
::
make_shared
<
abstract
::
AbstractTuple
>
(
abstract_list
));
return
make_tuple
;
}
CNodePtr
ConvertTupleInputToMakeTuple
(
const
FuncGraphPtr
&
graph
,
const
CNodePtr
&
cnode_ptr
)
{
MS_EXCEPTION_IF_NULL
(
cnode_ptr
);
MS_EXCEPTION_IF_NULL
(
graph
);
...
...
@@ -35,19 +67,25 @@ CNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const CNodePtr
std
::
vector
<
TypeId
>
types
;
std
::
vector
<
std
::
vector
<
size_t
>>
shapes
;
std
::
vector
<
AnfNodePtr
>
make_tuple_inputs_list
=
{
NewValueNode
(
prim
::
kPrimMakeTuple
)};
for
(
size_t
tuple_out_index
=
0
;
tuple_out_index
<
AnfAlgo
::
GetOutputTensorNum
(
input_node
);
++
tuple_out_index
)
{
make_tuple_inputs_list
.
emplace_back
(
CreatTupleGetItemNode
(
graph
,
input_node
,
tuple_out_index
));
types
.
push_back
(
AnfAlgo
::
GetOutputInferDataType
(
input_node
,
tuple_out_index
));
shapes
.
emplace_back
(
AnfAlgo
::
GetOutputInferShape
(
input_node
,
tuple_out_index
));
if
(
input_node
->
isa
<
CNode
>
())
{
for
(
size_t
tuple_out_index
=
0
;
tuple_out_index
<
AnfAlgo
::
GetOutputTensorNum
(
input_node
);
++
tuple_out_index
)
{
make_tuple_inputs_list
.
emplace_back
(
CreatTupleGetItemNode
(
graph
,
input_node
,
tuple_out_index
));
types
.
push_back
(
AnfAlgo
::
GetOutputInferDataType
(
input_node
,
tuple_out_index
));
shapes
.
emplace_back
(
AnfAlgo
::
GetOutputInferShape
(
input_node
,
tuple_out_index
));
}
auto
make_tuple
=
graph
->
NewCNode
(
make_tuple_inputs_list
);
AnfAlgo
::
SetOutputInferTypeAndShape
(
types
,
shapes
,
make_tuple
.
get
());
convert_inputs
.
emplace_back
(
make_tuple
);
continue
;
}
auto
make_tuple
=
graph
->
NewCNode
(
make_tuple_inputs_list
);
AnfAlgo
::
SetOutputInferTypeAndShape
(
types
,
shapes
,
make_tuple
.
get
());
convert_inputs
.
emplace_back
(
make_tuple
);
convert_inputs
.
emplace_back
(
ConvertTupleOuputToPlantInputs
(
graph
,
input_node
));
}
else
{
convert_inputs
.
push_back
(
input_node
);
}
}
return
graph
->
NewCNode
(
convert_inputs
);
auto
new_node
=
graph
->
NewCNode
(
convert_inputs
);
new_node
->
set_abstract
(
cnode_ptr
->
abstract
());
return
new_node
;
}
}
// namespace
...
...
mindspore/ccsrc/backend/session/kernel_graph.cc
浏览文件 @
485ac838
...
...
@@ -79,31 +79,6 @@ std::vector<AnfNodePtr> GetCallRealOutputs(const AnfNodePtr &call_node) {
return
real_inputs
;
}
AnfNodePtr
MakeValueNode
(
const
AnfNodePtr
&
node
)
{
auto
value_node
=
node
->
cast
<
ValueNodePtr
>
();
if
(
value_node
==
nullptr
)
{
return
nullptr
;
}
ValueNodePtr
new_value_node
=
std
::
make_shared
<
ValueNode
>
(
value_node
->
value
());
new_value_node
->
set_abstract
(
value_node
->
abstract
());
// create kernel_info fo new value node
auto
kernel_info
=
std
::
make_shared
<
device
::
KernelInfo
>
();
new_value_node
->
set_kernel_info
(
kernel_info
);
// create kernel_build_info for new value node
auto
kernel_build_info_builder
=
std
::
make_shared
<
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
>
();
// set the format of value_node to DEFAULT_FORMAT
kernel_build_info_builder
->
SetOutputsFormat
(
std
::
vector
<
std
::
string
>
{
kOpFormat_DEFAULT
});
// set value node initial device data type = infer data type
std
::
vector
<
TypeId
>
types
;
for
(
size_t
index
=
0
;
index
<
AnfAlgo
::
GetOutputTensorNum
(
value_node
);
++
index
)
{
types
.
push_back
(
kTypeUnknown
);
}
kernel_build_info_builder
->
SetOutputsDeviceType
(
types
);
AnfAlgo
::
SetSelectKernelBuildInfo
(
kernel_build_info_builder
->
Build
(),
new_value_node
.
get
());
return
new_value_node
;
}
bool
IsSameLabel
(
const
CNodePtr
&
left
,
const
CNodePtr
&
right
)
{
if
(
left
==
right
)
{
return
true
;
...
...
@@ -121,6 +96,18 @@ bool IsSameLabel(const CNodePtr &left, const CNodePtr &right) {
return
false
;
}
}
// namespace
AnfNodePtr
KernelGraph
::
MakeValueNode
(
const
AnfNodePtr
&
node
)
{
auto
value_node
=
node
->
cast
<
ValueNodePtr
>
();
if
(
value_node
==
nullptr
)
{
return
nullptr
;
}
ValueNodePtr
new_value_node
=
std
::
make_shared
<
ValueNode
>
(
value_node
->
value
());
new_value_node
->
set_abstract
(
value_node
->
abstract
());
this
->
SetKernelInfoForNode
(
new_value_node
);
return
new_value_node
;
}
std
::
vector
<
AnfNodePtr
>
KernelGraph
::
outputs
()
const
{
auto
graph_output
=
output
();
if
(
IsPrimitiveCNode
(
graph_output
,
prim
::
kPrimMakeTuple
))
{
...
...
@@ -290,28 +277,10 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
MS_EXCEPTION_IF_NULL
(
cnode
);
cnode
->
set_abstract
(
std
::
make_shared
<
abstract
::
AbstractNone
>
());
CreateKernelInfoFromNewParameter
(
cnode
);
auto
kernel_info
=
std
::
make_shared
<
device
::
KernelInfo
>
();
std
::
vector
<
size_t
>
feature_map_input_indexs
;
// if the node only has the primitive(such as getNext) or the node's input has a feature map input
// then the node's output is a feature map output
for
(
size_t
index
=
1
;
index
<
inputs
.
size
();
++
index
)
{
auto
node
=
inputs
[
index
];
if
(
AnfAlgo
::
IsFeatureMapOutput
(
node
))
{
feature_map_input_indexs
.
push_back
(
index
);
}
}
if
(
AnfAlgo
::
GetCNodeName
(
cnode
)
==
prim
::
kPrimCast
->
name
())
{
AnfAlgo
::
SetNodeAttr
(
kIsBackendCast
,
MakeValue
(
false
),
cnode
);
}
if
(
inputs
.
size
()
==
1
||
!
feature_map_input_indexs
.
empty
())
{
kernel_info
->
SetFeatureMapFlag
(
true
);
}
if
(
AnfAlgo
::
IsRealKernel
(
cnode
))
{
AnfAlgo
::
SetNodeAttr
(
kIsFeatureMapOutput
,
MakeValue
(
kernel_info
->
is_feature_map
()),
cnode
);
AnfAlgo
::
SetNodeAttr
(
kIsFeatureMapInputList
,
MakeValue
(
feature_map_input_indexs
),
cnode
);
}
cnode
->
set_kernel_info
(
kernel_info
);
SetKernelInfoForNode
(
cnode
);
AnfAlgo
::
SetGraphId
(
graph_id_
,
cnode
.
get
());
return
cnode
;
}
...
...
@@ -351,6 +320,50 @@ void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) {
}
}
void
KernelGraph
::
SetKernelInfoForNode
(
const
AnfNodePtr
&
node
)
const
{
MS_EXCEPTION_IF_NULL
(
node
);
auto
kernel_info
=
std
::
make_shared
<
device
::
KernelInfo
>
();
node
->
set_kernel_info
(
kernel_info
);
if
(
node
->
isa
<
CNode
>
())
{
std
::
vector
<
size_t
>
feature_map_input_indexs
;
kernel_info
->
SetFeatureMapFlag
(
false
);
for
(
size_t
index
=
0
;
index
<
AnfAlgo
::
GetInputTensorNum
(
node
);
++
index
)
{
if
(
AnfAlgo
::
IsFeatureMapInput
(
node
,
index
))
{
kernel_info
->
SetFeatureMapFlag
(
true
);
feature_map_input_indexs
.
push_back
(
index
);
}
}
if
(
AnfAlgo
::
GetInputTensorNum
(
node
)
==
0
)
{
kernel_info
->
SetFeatureMapFlag
(
true
);
}
if
(
AnfAlgo
::
IsRealKernel
(
node
))
{
// if the node only has the primitive(such as getNext) or the node's input has a feature map input
// then the node's output is a feature map output
AnfAlgo
::
SetNodeAttr
(
kIsFeatureMapOutput
,
MakeValue
(
kernel_info
->
is_feature_map
()),
node
);
AnfAlgo
::
SetNodeAttr
(
kIsFeatureMapInputList
,
MakeValue
(
feature_map_input_indexs
),
node
);
}
return
;
}
auto
kernel_build_info_builder
=
std
::
make_shared
<
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
>
();
// set the format of value_node to DEFAULT_FORMAT
std
::
vector
<
TypeId
>
types
;
kernel_build_info_builder
->
SetOutputsFormat
(
std
::
vector
<
std
::
string
>
{
kOpFormat_DEFAULT
});
if
(
node
->
isa
<
ValueNode
>
())
{
kernel_info
->
SetFeatureMapFlag
(
false
);
types
.
emplace_back
(
kTypeUnknown
);
}
if
(
node
->
isa
<
Parameter
>
())
{
auto
parameter
=
node
->
cast
<
ParameterPtr
>
();
MS_EXCEPTION_IF_NULL
(
parameter
);
bool
is_weight
=
AnfAlgo
::
IsParameterWeight
(
parameter
);
kernel_info
->
SetFeatureMapFlag
(
!
is_weight
);
types
.
push_back
(
is_weight
?
kTypeUnknown
:
AnfAlgo
::
GetOutputInferDataType
(
parameter
,
0
));
}
// set parameter initaial device data type
kernel_build_info_builder
->
SetOutputsDeviceType
(
types
);
AnfAlgo
::
SetSelectKernelBuildInfo
(
kernel_build_info_builder
->
Build
(),
node
.
get
());
}
CNodePtr
KernelGraph
::
NewCNode
(
const
CNodePtr
&
cnode
)
{
MS_EXCEPTION_IF_NULL
(
cnode
);
auto
new_cnode
=
std
::
make_shared
<
CNode
>
(
*
cnode
);
...
...
@@ -366,75 +379,97 @@ CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) {
}
ParameterPtr
KernelGraph
::
NewParameter
(
const
ParameterPtr
&
parameter
)
{
ParameterPtr
new_parameter
=
add_parameter
();
auto
abstract
=
parameter
==
nullptr
?
std
::
make_shared
<
abstract
::
AbstractNone
>
()
:
parameter
->
abstract
();
auto
new_parameter
=
NewParameter
(
abstract
);
MS_EXCEPTION_IF_NULL
(
new_parameter
);
// create kernel_info form new parameter
auto
kernel_info
=
std
::
make_shared
<
device
::
KernelInfo
>
();
size_t
output_tensor_num
=
1
;
// if use default parameter = nullptr,it remarks create a new parameter from no parameter
if
(
parameter
==
nullptr
)
{
new_parameter
->
set_abstract
(
std
::
make_shared
<
abstract
::
AbstractNone
>
());
kernel_info
->
SetFeatureMapFlag
(
true
);
}
else
{
// if don't use default parameter = nullptr,it remarks create a new parameter from a old parameter
new_parameter
->
set_abstract
(
parameter
->
abstract
());
// if don't use default parameter = nullptr,it remarks create a new parameter from a old parameter
if
(
parameter
!=
nullptr
)
{
new_parameter
->
set_name
(
parameter
->
name
());
if
(
AnfAlgo
::
IsParameterWeight
(
parameter
))
{
new_parameter
->
set_default_param
(
parameter
->
default_param
());
kernel_info
->
SetFeatureMapFlag
(
false
);
}
else
{
kernel_info
->
SetFeatureMapFlag
(
true
);
}
}
new_parameter
->
set_kernel_info
(
kernel_info
);
// create kernel_build_info for new parameter
auto
kernel_build_info_builder
=
std
::
make_shared
<
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
>
();
// create init data type,
std
::
vector
<
TypeId
>
init_data_type
=
{};
TypeId
infer_data_type
=
AnfAlgo
::
GetOutputInferDataType
(
new_parameter
,
0
);
init_data_type
.
push_back
(
AnfAlgo
::
IsParameterWeight
(
new_parameter
)
?
kTypeUnknown
:
infer_data_type
);
// create kernel_info form new parameter
SetKernelInfoForNode
(
new_parameter
);
AnfAlgo
::
SetGraphId
(
graph_id_
,
new_parameter
.
get
());
return
new_parameter
;
}
// set the format of parameter to DEFAULT_FORMAT
kernel_build_info_builder
->
SetOutputsFormat
(
std
::
vector
<
std
::
string
>
(
output_tensor_num
,
kOpFormat_DEFAULT
));
// set parameter initaial device data type
kernel_build_info_builder
->
SetOutputsDeviceType
(
init_data_type
);
AnfAlgo
::
SetSelectKernelBuildInfo
(
kernel_build_info_builder
->
Build
(),
new_parameter
.
get
());
ParameterPtr
KernelGraph
::
NewParameter
(
const
abstract
::
AbstractBasePtr
&
abstract
)
{
ParameterPtr
new_parameter
=
add_parameter
();
new_parameter
->
set_abstract
(
abstract
);
MS_EXCEPTION_IF_NULL
(
new_parameter
);
// create kernel_info form new parameter
SetKernelInfoForNode
(
new_parameter
);
AnfAlgo
::
SetGraphId
(
graph_id_
,
new_parameter
.
get
());
return
new_parameter
;
}
std
::
vector
<
AnfNodePtr
>
KernelGraph
::
SplitTupleParameterToNodeList
(
const
ParameterPtr
&
parameter
)
{
MS_EXCEPTION_IF_NULL
(
parameter
);
std
::
vector
<
AnfNodePtr
>
convert_nodes_list
;
auto
abstract
=
parameter
->
abstract
();
MS_EXCEPTION_IF_NULL
(
abstract
);
if
(
!
abstract
->
isa
<
abstract
::
AbstractTuple
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"Multiple output Parameter's output must be a tuple abstract but got "
<<
abstract
->
ToString
();
}
auto
tuple_abstract
=
abstract
->
cast
<
abstract
::
AbstractTuplePtr
>
();
MS_EXCEPTION_IF_NULL
(
tuple_abstract
);
for
(
size_t
index
=
0
;
index
<
tuple_abstract
->
size
();
++
index
)
{
auto
new_parameter
=
this
->
NewParameter
((
*
tuple_abstract
)[
index
]);
SetKernelInfoForNode
(
new_parameter
);
convert_nodes_list
.
emplace_back
(
new_parameter
);
}
auto
new_inputs
=
std
::
make_shared
<
std
::
vector
<
AnfNodePtr
>>
();
auto
old_inputs
=
inputs
();
for
(
const
auto
&
input_node
:
old_inputs
)
{
if
(
input_node
!=
parameter
)
{
new_inputs
->
emplace_back
(
input_node
);
continue
;
}
std
::
copy
(
convert_nodes_list
.
begin
(),
convert_nodes_list
.
end
(),
std
::
back_inserter
(
*
new_inputs
));
}
inputs_
=
new_inputs
;
return
convert_nodes_list
;
}
std
::
vector
<
AnfNodePtr
>
KernelGraph
::
SplitTupleOutputNodeToNodeList
(
const
AnfNodePtr
&
node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
if
(
node
->
isa
<
CNode
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"The function can only split a parameter or valuenode bug got "
<<
node
->
DebugString
();
}
if
(
node
->
isa
<
Parameter
>
())
{
return
SplitTupleParameterToNodeList
(
node
->
cast
<
ParameterPtr
>
());
}
return
SplitTupleValueNodeToNodeList
(
node
->
cast
<
ValueNodePtr
>
());
}
std
::
vector
<
AnfNodePtr
>
KernelGraph
::
SplitTupleValueNodeToNodeList
(
const
ValueNodePtr
&
value_node
)
{
MS_EXCEPTION_IF_NULL
(
value_node
);
auto
node_value
=
value_node
->
value
();
auto
output_size
=
AnfAlgo
::
GetOutputTensorNum
(
value_node
);
std
::
vector
<
AnfNodePtr
>
convert_inputs
;
if
(
!
node_value
->
isa
<
ValueTuple
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"Multiple output valuenode's value must be a value tuple but got "
<<
node_value
->
ToString
();
}
auto
value_tuple
=
node_value
->
cast
<
ValueTuplePtr
>
();
MS_EXCEPTION_IF_NULL
(
value_tuple
);
if
(
value_tuple
->
size
()
!=
output_size
)
{
MS_LOG
(
EXCEPTION
)
<<
"Value tuple size"
<<
value_tuple
->
size
()
<<
" is not mathced with the value node's output size"
<<
output_size
;
auto
abstract
=
value_node
->
abstract
();
if
(
!
abstract
->
isa
<
abstract
::
AbstractTuple
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"Spilted node's output abstract is not type tuple"
;
}
auto
tuple_abstract
=
abstract
->
cast
<
abstract
::
AbstractTuplePtr
>
();
MS_EXCEPTION_IF_NULL
(
tuple_abstract
);
if
(
tuple_abstract
->
size
()
!=
value_tuple
->
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"The node output index ["
<<
value_tuple
->
size
()
<<
"]is outof range "
<<
tuple_abstract
->
size
();
}
for
(
size_t
index
=
0
;
index
<
value_tuple
->
value
().
size
();
++
index
)
{
auto
new_value_node
=
std
::
make_shared
<
ValueNode
>
(
value_tuple
->
value
()[
index
]);
AnfAlgo
::
SetOutputInferTypeAndShape
({
AnfAlgo
::
GetOutputInferDataType
(
value_node
,
index
)},
{
AnfAlgo
::
GetOutputInferShape
(
value_node
,
index
)},
new_value_node
.
get
());
new_value_node
->
set_abstract
((
*
tuple_abstract
)[
index
]);
AddValueNodeToGraph
(
new_value_node
);
auto
kernel_info
=
std
::
make_shared
<
device
::
KernelInfo
>
();
new_value_node
->
set_kernel_info
(
kernel_info
);
kernel_info
->
SetFeatureMapFlag
(
false
);
// create kernel_build_info for new value node
auto
kernel_build_info_builder
=
std
::
make_shared
<
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
>
();
// set the format of value_node to DEFAULT_FORMAT
kernel_build_info_builder
->
SetOutputsFormat
({
kOpFormat_DEFAULT
});
// set value node initial device data type = infer data type
kernel_build_info_builder
->
SetOutputsDeviceType
({
kTypeUnknown
});
AnfAlgo
::
SetSelectKernelBuildInfo
(
kernel_build_info_builder
->
Build
(),
new_value_node
.
get
());
SetKernelInfoForNode
(
new_value_node
);
AnfAlgo
::
SetGraphId
(
graph_id_
,
new_value_node
.
get
());
AddValueNodeToGraph
(
new_value_node
);
convert_inputs
.
emplace_back
(
new_value_node
);
}
if
(
!
RemoveValueNodeFromGraph
(
value_node
))
{
...
...
mindspore/ccsrc/backend/session/kernel_graph.h
浏览文件 @
485ac838
...
...
@@ -54,8 +54,10 @@ class KernelGraph : public FuncGraph {
void
CreateKernelInfoFromNewParameter
(
const
CNodePtr
&
cnode
);
CNodePtr
NewCNode
(
const
CNodePtr
&
cnode
);
ParameterPtr
NewParameter
(
const
ParameterPtr
&
parameter
=
nullptr
);
ParameterPtr
NewParameter
(
const
abstract
::
AbstractBasePtr
&
abstract
);
ValueNodePtr
NewValueNode
(
const
ValuePtr
&
value
);
ValueNodePtr
NewValueNode
(
const
ValueNodePtr
&
value_node
=
nullptr
);
std
::
vector
<
AnfNodePtr
>
SplitTuple
ValueNodeToNodeList
(
const
ValueNodePtr
&
value_
node
);
std
::
vector
<
AnfNodePtr
>
SplitTuple
OutputNodeToNodeList
(
const
AnfNodePtr
&
node
);
void
set_execution_order
(
const
std
::
vector
<
CNodePtr
>
&
order
)
{
execution_order_
=
order
;
}
const
std
::
vector
<
CNodePtr
>
&
execution_order
()
const
{
return
execution_order_
;
}
void
SetExecOrderByDefault
();
...
...
@@ -166,6 +168,10 @@ class KernelGraph : public FuncGraph {
private:
// remove value node form graph
bool
RemoveValueNodeFromGraph
(
const
ValueNodePtr
&
value_node
);
void
SetKernelInfoForNode
(
const
AnfNodePtr
&
node
)
const
;
std
::
vector
<
AnfNodePtr
>
SplitTupleValueNodeToNodeList
(
const
ValueNodePtr
&
value_node
);
std
::
vector
<
AnfNodePtr
>
SplitTupleParameterToNodeList
(
const
ParameterPtr
&
parameter
);
AnfNodePtr
MakeValueNode
(
const
AnfNodePtr
&
node
);
void
VisitNodeDescendants
(
const
AnfNodePtr
&
node
,
std
::
queue
<
AnfNodePtr
>
*
visit_queue
,
std
::
unordered_set
<
AnfNodePtr
>
*
visited_nodes
);
// update node edge list
...
...
tests/ut/cpp/session/kernel_graph_test.cc
浏览文件 @
485ac838
...
...
@@ -60,7 +60,7 @@ TEST_F(KernelGraphTest, NewParameter) {
auto
anf_graph
=
std
::
make_shared
<
FuncGraph
>
();
auto
kernel_graph
=
std
::
make_shared
<
KernelGraph
>
();
// test nullptr as input
auto
new_paramter
=
kernel_graph
->
NewParameter
(
nullptr
);
auto
new_paramter
=
kernel_graph
->
NewParameter
();
EXPECT_NE
(
new_paramter
,
nullptr
);
EXPECT_TRUE
(
new_paramter
->
isa
<
Parameter
>
());
EXPECT_EQ
(
AnfAlgo
::
GetOutputFormat
(
new_paramter
,
0
),
kOpFormat_DEFAULT
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录