Commit e121bcd3
Authored Sep 01, 2020 by zhengjun10
constant fold approve multi output
Parent: f42b3bbf

Showing 4 changed files with 289 additions and 57 deletions (+289 −57). Inline diff; whitespace changes hidden.
mindspore/lite/test/ut/tools/optimizer/fusion/constant_folding_fusion_test.cc   +145 −0
mindspore/lite/tools/optimizer/common/gllo_utils.cc                             +48 −3
mindspore/lite/tools/optimizer/common/gllo_utils.h                              +2 −0
mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc                +94 −54
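In short (a reader's orientation note, not part of the commit): the constant-folding pass could previously only replace a folded CNode with a single new Parameter, which breaks down when the folded node produces several output tensors. This commit adds the helpers GetTupleGetItemOutIndex and GetRealNodeUsedListByOutputIdx to gllo_utils, factors the replacement logic into a new ReplaceCNode routine that gives each output tensor of a multi-output node (e.g. Split) its own Parameter via the consuming TupleGetItem node, and adds a TestSplitConstantFold unit test asserting that a fully constant Split/Mul graph folds away completely.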
mindspore/lite/test/ut/tools/optimizer/fusion/constant_folding_fusion_test.cc
@@ -236,6 +236,136 @@ MetaGraphTptr BuildMixGraph() {
   // final output
   return meta_graph;
 }
+
+MetaGraphTptr BuildSplitGraph() {
+  auto meta_graph = std::make_shared<schema::MetaGraphT>();
+  meta_graph->name = "graph";
+  // slice node
+  auto split_node = std::make_unique<schema::CNodeT>();
+  split_node->inputIndex = {0};
+  split_node->outputIndex = {1, 2};
+  split_node->primitive = std::make_unique<schema::PrimitiveT>();
+  split_node->primitive->value.type = schema::PrimitiveType_Split;
+  std::unique_ptr<schema::SplitT> attr = std::make_unique<schema::SplitT>();
+  attr->numberSplit = 2;
+  attr->splitDim = 1;
+  split_node->primitive->value.value = attr.release();
+  split_node->name = "split";
+  meta_graph->nodes.emplace_back(std::move(split_node));
+  meta_graph->inputIndex = {0, 3, 4};
+  meta_graph->outputIndex = {5, 6};
+  auto mul_node1 = std::make_unique<schema::CNodeT>();
+  mul_node1->inputIndex = {1, 3};
+  mul_node1->outputIndex = {5};
+  mul_node1->primitive = std::make_unique<schema::PrimitiveT>();
+  mul_node1->primitive->value.type = schema::PrimitiveType_Mul;
+  std::unique_ptr<schema::MulT> mul_attr = std::make_unique<schema::MulT>();
+  mul_node1->primitive->value.value = mul_attr.release();
+  mul_node1->name = "mul1";
+  meta_graph->nodes.emplace_back(std::move(mul_node1));
+  auto mul_node2 = std::make_unique<schema::CNodeT>();
+  mul_node2->inputIndex = {2, 4};
+  mul_node2->outputIndex = {6};
+  mul_node2->primitive = std::make_unique<schema::PrimitiveT>();
+  mul_node2->primitive->value.type = schema::PrimitiveType_Mul;
+  std::unique_ptr<schema::MulT> mul2_attr = std::make_unique<schema::MulT>();
+  mul_node2->primitive->value.value = mul2_attr.release();
+  mul_node2->name = "mul2";
+  meta_graph->nodes.emplace_back(std::move(mul_node2));
+  // input 0: data1
+  auto input0 = std::make_unique<schema::TensorT>();
+  input0->nodeType = schema::NodeType::NodeType_ValueNode;
+  input0->format = schema::Format_NHWC;
+  input0->dataType = TypeId::kNumberTypeFloat32;
+  input0->dims = {1, 2, 2, 3};
+  input0->offset = -1;
+  auto input0_data = new (std::nothrow) float[2 * 2 * 3];
+  for (auto i = 0; i < 2 * 2 * 3; i++) {
+    input0_data[i] = i;
+  }
+  input0->data.resize(sizeof(float) * 2 * 2 * 3);
+  memcpy(input0->data.data(), input0_data, 2 * 2 * 3 * sizeof(float));
+  delete[] input0_data;
+  meta_graph->allTensors.emplace_back(std::move(input0));
+  // split output1
+  auto split_output1 = std::make_unique<schema::TensorT>();
+  split_output1->nodeType = schema::NodeType::NodeType_Parameter;
+  split_output1->format = schema::Format_NHWC;
+  split_output1->dataType = TypeId::kNumberTypeFloat32;
+  split_output1->dims = {1, 1, 2, 3};
+  split_output1->offset = -1;
+  split_output1->data.resize(sizeof(float) * 1 * 2 * 3);
+  auto split_output_data1 = new (std::nothrow) float[1 * 2 * 3];
+  memcpy(split_output1->data.data(), split_output_data1, 1 * 2 * 3 * sizeof(float));
+  delete[] split_output_data1;
+  meta_graph->allTensors.emplace_back(std::move(split_output1));
+  // split output2
+  auto split_output2 = std::make_unique<schema::TensorT>();
+  split_output2->nodeType = schema::NodeType::NodeType_Parameter;
+  split_output2->format = schema::Format_NHWC;
+  split_output2->dataType = TypeId::kNumberTypeFloat32;
+  split_output2->dims = {1, 1, 2, 3};
+  split_output2->offset = -1;
+  split_output2->data.resize(sizeof(float) * 1 * 2 * 3);
+  auto split_output_data2 = new (std::nothrow) float[1 * 2 * 3];
+  memcpy(split_output2->data.data(), split_output_data2, 1 * 2 * 3 * sizeof(float));
+  delete[] split_output_data2;
+  meta_graph->allTensors.emplace_back(std::move(split_output2));
+  // input 1: data2
+  auto input1 = std::make_unique<schema::TensorT>();
+  input1->nodeType = schema::NodeType::NodeType_ValueNode;
+  input1->format = schema::Format_NHWC;
+  input1->dataType = TypeId::kNumberTypeFloat32;
+  input1->dims = {1, 1, 2, 3};
+  input1->offset = -1;
+  input1->data.resize(sizeof(float) * 2 * 3);
+  auto input1_data = new (std::nothrow) float[2 * 3];
+  for (auto i = 0; i < 2 * 3; i++) {
+    input1_data[i] = i;
+  }
+  memcpy(input1->data.data(), input1_data, 2 * 3 * sizeof(float));
+  delete[] input1_data;
+  meta_graph->allTensors.emplace_back(std::move(input1));
+  // input 2: data3
+  auto input2 = std::make_unique<schema::TensorT>();
+  input2->nodeType = schema::NodeType::NodeType_ValueNode;
+  input2->format = schema::Format_NHWC;
+  input2->dataType = TypeId::kNumberTypeFloat32;
+  input2->dims = {1, 1, 2, 3};
+  input2->offset = -1;
+  input2->data.resize(sizeof(float) * 2 * 3);
+  auto input2_data = new (std::nothrow) float[2 * 3];
+  for (auto i = 0; i < 2 * 3; i++) {
+    input2_data[i] = 10;
+  }
+  memcpy(input2->data.data(), input2_data, 2 * 3 * sizeof(float));
+  delete[] input2_data;
+  meta_graph->allTensors.emplace_back(std::move(input2));
+  // final mul output1
+  auto mul_output = std::make_unique<schema::TensorT>();
+  mul_output->nodeType = schema::NodeType::NodeType_Parameter;
+  mul_output->format = schema::Format_NHWC;
+  mul_output->dataType = TypeId::kNumberTypeFloat32;
+  mul_output->dims = {1, 1, 2, 3};
+  meta_graph->allTensors.emplace_back(std::move(mul_output));
+  // final mul output2
+  auto mul_output2 = std::make_unique<schema::TensorT>();
+  mul_output2->nodeType = schema::NodeType::NodeType_Parameter;
+  mul_output2->format = schema::Format_NHWC;
+  mul_output2->dataType = TypeId::kNumberTypeFloat32;
+  mul_output2->dims = {1, 1, 2, 3};
+  meta_graph->allTensors.emplace_back(std::move(mul_output2));
+  return meta_graph;
+}
 }  // namespace

 TEST_F(ConstantFoldingFusionTest, TestADDConstantFold) {
   auto meta_graph = BuildGraph(schema::PrimitiveType_Add, new schema::AddT);
@@ -483,4 +613,19 @@ TEST_F(ConstantFoldingFusionTest, TestCastDimsConstantFold) {
   auto new_meta_graph = lite::Export(new_graph);
   ASSERT_EQ(new_meta_graph->nodes.size(), 0);
 }
+
+TEST_F(ConstantFoldingFusionTest, TestSplitConstantFold) {
+  auto meta_graph = BuildSplitGraph();
+  auto input_tensor = meta_graph->allTensors.at(0).get();
+  input_tensor->dataType = kNumberTypeFloat32;
+  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>("test", false);
+  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
+  ASSERT_NE(nullptr, new_graph);
+  auto new_meta_graph = lite::Export(new_graph);
+  ASSERT_EQ(new_meta_graph->nodes.size(), 0);
+}
 }  // namespace mindspore
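For orientation, here is a sketch of the tensor wiring in BuildSplitGraph (a reader's summary derived from the indices above, not part of the diff):

// Tensor index map for BuildSplitGraph:
//   0: input0  (1x2x2x3, constant data 0..11)   -- graph input, fed to Split
//   1: split_output1 (1x1x2x3) = Split(0)[0]
//   2: split_output2 (1x1x2x3) = Split(0)[1]
//   3: input1  (1x1x2x3, constant data 0..5)    -- second operand of mul1
//   4: input2  (1x1x2x3, constant 10s)          -- second operand of mul2
//   5: mul_output  = Mul(tensor 1, tensor 3)    -- graph output
//   6: mul_output2 = Mul(tensor 2, tensor 4)    -- graph output
// Every operand is constant, so ConstFoldPass should fold the whole graph away,
// which is exactly what ASSERT_EQ(new_meta_graph->nodes.size(), 0) checks.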
mindspore/lite/tools/optimizer/common/gllo_utils.cc
@@ -319,7 +319,7 @@ schema::PrimitiveType GetCNodeType(const BaseRef &n) {
   if (utils::isa<PrimitiveCPtr>(value)) {
     auto primitive = value->cast<PrimitiveCPtr>();
     MS_ASSERT(primitive != nullptr);
-    return (schema::PrimitiveType)primitive->Type();
+    return (schema::PrimitiveType)primitive->Type();
   } else if (utils::isa<Primitive>(value)) {
     auto primitive = value->cast<PrimitivePtr>();
     MS_ASSERT(primitive != nullptr);
@@ -392,8 +392,8 @@ size_t GetOutputTensorNum(const AnfNodePtr &node) {
 bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node) {
   auto output_node_list = GetRealNodeUsedList(graph, node);
   if (output_node_list->size() != 1) {
-    MS_LOG(DEBUG) << "fusion node has multi output nodes";
-    return true;
+    MS_LOG(DEBUG) << "fusion node has multi output nodes";
+    return true;
   }
   return false;
 }
@@ -412,5 +412,50 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(con
   std::copy(output_info_list.begin(), output_info_list.end(), std::back_inserter(*output_node_list));
   return output_node_list;
 }
+
+size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
+  MS_ASSERT(tuple_get_item != nullptr);
+  if (tuple_get_item->size() != kTupleGetItemInputSize) {
+    MS_LOG(ERROR) << "The node tuple_get_item must have 2 inputs!";
+    return -1;
+  }
+  auto output_index_value_node = tuple_get_item->input(kInputNodeOutputIndexInTupleGetItem);
+  MS_ASSERT(output_index_value_node != nullptr);
+  auto value_node = output_index_value_node->cast<ValueNodePtr>();
+  MS_ASSERT(value_node != nullptr);
+  return IntToSize(GetValue<int>(value_node->value()));
+}
+
+std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
+                                                                                        const AnfNodePtr &node,
+                                                                                        size_t output_index) {
+  MS_ASSERT(graph != nullptr);
+  MS_ASSERT(node != nullptr);
+  auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
+  auto manager = graph->manager();
+  MS_ASSERT(manager != nullptr);
+  auto iter = manager->node_users().find(node);
+  if (iter == manager->node_users().end()) {
+    MS_LOG(ERROR) << "node has no output in manager";
+    return output_node_list;
+  }
+  auto output_info_list = iter->second;
+  for (const auto &output_info : output_info_list) {
+    size_t used_output_index;
+    if (GetCNodeType(output_info.first) == schema::PrimitiveType_TupleGetItem) {
+      used_output_index = GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first));
+    } else if (GetCNodeType(node) == schema::PrimitiveType_TupleGetItem) {
+      used_output_index = output_index;
+    } else {
+      if (output_index != 0) {
+        MS_LOG(ERROR) << "node has no output in manager";
+        return output_node_list;
+      }
+      return output_node_list;
+    }
+    if (used_output_index == output_index) {
+      output_node_list->push_back(output_info);
+    }
+  }
+  return output_node_list;
+}
 }  // namespace opt
 }  // namespace mindspore
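To make the per-output dispatch rule concrete, here is a small self-contained model (plain C++17; all names here are hypothetical stand-ins written for this note, mirroring the branch structure of GetRealNodeUsedListByOutputIdx above rather than the real MindSpore API):

#include <cstdio>
#include <utility>
#include <vector>

// Toy stand-ins for the ANF machinery (illustration only).
enum class NodeKind { kTupleGetItem, kOther };
struct Node {
  NodeKind kind;
  size_t tuple_index;  // meaningful only when kind == kTupleGetItem
};

// Model of the per-output-index user filter: given all users of a node,
// keep those that consume output `output_index`.
std::vector<std::pair<const Node *, int>> UsersByOutputIdx(
    const std::vector<std::pair<const Node *, int>> &all_users, size_t output_index) {
  std::vector<std::pair<const Node *, int>> out;
  for (const auto &user : all_users) {
    size_t used_output_index;
    if (user.first->kind == NodeKind::kTupleGetItem) {
      // A TupleGetItem user names the output it reads via its index input.
      used_output_index = user.first->tuple_index;
    } else {
      // Non-TupleGetItem users can only consume output 0 (simplified here;
      // the real helper bails out of the loop in this case).
      if (output_index != 0) continue;
      used_output_index = 0;
    }
    if (used_output_index == output_index) out.push_back(user);
  }
  return out;
}

int main() {
  Node get0{NodeKind::kTupleGetItem, 0}, get1{NodeKind::kTupleGetItem, 1};
  std::vector<std::pair<const Node *, int>> users = {{&get0, 1}, {&get1, 1}};
  std::printf("users of output 0: %zu\n", UsersByOutputIdx(users, 0).size());  // 1
  std::printf("users of output 1: %zu\n", UsersByOutputIdx(users, 1).size());  // 1
  return 0;
}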
mindspore/lite/tools/optimizer/common/gllo_utils.h
@@ -63,6 +63,8 @@ bool CheckIsAllInputsParam(const AnfNodePtr &node);
 size_t GetOutputTensorNum(const AnfNodePtr &node);
 bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node);
+size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item);
 }  // namespace opt
 }  // namespace mindspore
 #endif  // MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_
mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc
@@ -41,7 +41,7 @@ std::vector<Tensor *> GetCNodeInputTensors(const CNodePtr &CNode) {
     auto tensorT = tmp_meta_graph->allTensors.at(input_index).get();
     auto tensor_shape = tensorT->dims;
-    auto lite_tensor = new (std::nothrow) Tensor(TypeId(tensorT->dataType), tensor_shape, tensorT->format, tensorT->nodeType);
+    auto lite_tensor =
+      new (std::nothrow) Tensor(TypeId(tensorT->dataType), tensor_shape, tensorT->format, tensorT->nodeType);
     if (lite_tensor == nullptr) {
       MS_LOG(ERROR) << "lite tensor is nullptr";
       return input_tensors;
@@ -106,7 +106,7 @@ kernel::LiteKernel *GetLiteKernel(std::vector<Tensor *> inputs, std::vector<Tens
                                   mindspore::lite::PrimitiveC *primitive) {
   MS_ASSERT(nullptr != lite_primitive);
   auto data_type = inputs.front()->data_type();
-  kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, (schema::PrimitiveType)primitive->Type()};
+  kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, (schema::PrimitiveType)primitive->Type()};
   lite::Context context;
   auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
   if (creator != nullptr) {
@@ -115,6 +115,44 @@ kernel::LiteKernel *GetLiteKernel(std::vector<Tensor *> inputs, std::vector<Tens
   }
   return nullptr;
 }
+
+lite::STATUS ReplaceCNode(const FuncGraphPtr &func_graph, const CNodePtr &any_node, const AnfNodePtr &input_node,
+                          std::vector<Tensor *> output_tensors, size_t replace_index) {
+  MS_ASSERT(func_graph != nullptr);
+  auto manager = func_graph->manager();
+  MS_ASSERT(manager != nullptr);
+  if (output_tensors.size() != 1) {
+    for (size_t k = 0; k < output_tensors.size(); k++) {
+      auto used_node_list = GetRealNodeUsedListByOutputIdx(func_graph, input_node, k);
+      if (used_node_list->size() != 1) {
+        MS_LOG(ERROR) << " output must tuple_getitem";
+        return lite::RET_ERROR;
+      }
+      auto tuple_node = used_node_list->at(0).first;
+      if (GetCNodeType(tuple_node) == schema::PrimitiveType_TupleGetItem) {
+        auto new_parameter = CreateNewParamter(func_graph, output_tensors.at(k));
+        if (new_parameter == nullptr) {
+          MS_LOG(ERROR) << "CreateNewParamter failed, name: " << input_node->fullname_with_scope();
+          return lite::RET_ERROR;
+        }
+        new_parameter->set_name(input_node->fullname_with_scope() + "_const_" + std::to_string(k));
+        manager->Replace(tuple_node, new_parameter);
+      } else {
+        MS_LOG(ERROR) << " multi out tensor must connect tuple-getitem: " << input_node->fullname_with_scope();
+        return lite::RET_ERROR;
+      }
+    }
+  } else {
+    auto new_parameter = CreateNewParamter(func_graph, output_tensors.front());
+    if (new_parameter == nullptr) {
+      MS_LOG(ERROR) << "CreateNewParamter failed, name: " << input_node->fullname_with_scope();
+      return lite::RET_ERROR;
+    }
+    new_parameter->set_name(input_node->fullname_with_scope());
+    any_node->set_input(replace_index, new_parameter);
+  }
+  return lite::RET_OK;
+}
 }  // namespace
 void FreeTensors(std::vector<Tensor *> *input_tensor, std::vector<Tensor *> *output_tensor) {
   if (input_tensor != nullptr) {
@@ -140,64 +178,66 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
   }
   auto any_node = node->cast<CNodePtr>();
   CheckIfCNodeIsNull(any_node);
+  bool changed = false;
   for (size_t i = 1; i < any_node->inputs().size(); i++) {
     auto input_node = any_node->input(i);
-    if (input_node->isa<CNode>() && CheckIsAllInputsParam(input_node)) {
-      auto input_cnode = input_node->cast<CNodePtr>();
-      auto input_tensors = GetCNodeInputTensors(input_cnode);
-      if (input_tensors.empty() || input_tensors.size() != input_cnode->inputs().size() - 1) {
-        FreeTensors(&input_tensors, nullptr);
-        continue;
-      }
-      MS_LOG(INFO) << "Begin fold node:" << input_node->fullname_with_scope();
-      auto output_nums = GetOutputTensorNum(input_cnode);
-      std::vector<Tensor *> output_tensors{output_nums, new Tensor()};
-      auto lite_primitive = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
-      if (lite_primitive == nullptr) {
-        MS_LOG(ERROR) << "lite_primitive is nullptr";
-        FreeTensors(&input_tensors, &output_tensors);
-        return nullptr;
-      }
-      // here, input_tensor's format need to be transposed nhwc according to fmkType,
-      // but for the time being, we only transpose the tensor with 0/1/2/3D.
-      // Others should be added in future.
-      for (size_t j = 0; j < input_tensors.size(); ++j) {
-        input_tensors[j]->SetFormat(schema::Format_NHWC);
-        if (input_tensors[j]->shape().size() == 4) {
-          MS_LOG(INFO) << "init input_tensor format to nhwc";
-        }
-      }
-      lite_primitive->InferShape(input_tensors, output_tensors);
-      auto parameter = kernel::PopulateParameter(lite_primitive.get());
-      if (parameter == nullptr) {
-        MS_LOG(ERROR) << "PopulateParameter return nullptr, type: "
-                      << schema::EnumNamePrimitiveType((schema::PrimitiveType)(lite_primitive->Type()));
-        return nullptr;
-      }
-      auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, lite_primitive.get());
-      if (lite_kernel == nullptr) {
-        MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr";
-        FreeTensors(&input_tensors, &output_tensors);
-        return nullptr;
-      }
-      auto ret = lite_kernel->Run();
-      if (0 != ret) {
-        FreeTensors(&input_tensors, &output_tensors);
-        MS_LOG(ERROR) << "run kernel failed, name: " << lite_kernel->name();
-        return nullptr;
-      }
-      auto new_parameter = CreateNewParamter(func_graph, output_tensors.front());
-      if (new_parameter == nullptr) {
-        FreeTensors(&input_tensors, &output_tensors);
-        MS_LOG(ERROR) << "CreateNewParamter failed, name: " << lite_kernel->name();
-        return nullptr;
-      }
-      new_parameter->set_name(input_node->fullname_with_scope());
-      any_node->set_input(i, new_parameter);
+    if (!input_node->isa<CNode>() || !CheckIsAllInputsParam(input_node)) {
+      continue;
     }
+    auto input_cnode = input_node->cast<CNodePtr>();
+    auto input_tensors = GetCNodeInputTensors(input_cnode);
+    if (input_tensors.empty() || input_tensors.size() != input_cnode->inputs().size() - 1) {
+      FreeTensors(&input_tensors, nullptr);
+      continue;
+    }
+    changed = true;
+    MS_LOG(INFO) << "Begin fold node:" << input_node->fullname_with_scope();
+    auto output_nums = GetOutputTensorNum(input_cnode);
+    std::vector<Tensor *> output_tensors{output_nums, new Tensor()};
+    auto lite_primitive = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
+    if (lite_primitive == nullptr) {
+      MS_LOG(ERROR) << "lite_primitive is nullptr";
+      FreeTensors(&input_tensors, &output_tensors);
+      return nullptr;
+    }
+    // here, input_tensor's format need to be transposed nhwc according to fmkType,
+    // but for the time being, we only transpose the tensor with 0/1/2/3D.
+    // Others should be added in future.
+    for (size_t j = 0; j < input_tensors.size(); ++j) {
+      input_tensors[j]->SetFormat(schema::Format_NHWC);
+      if (input_tensors[j]->shape().size() == 4) {
+        MS_LOG(INFO) << "init input_tensor format to nhwc";
+      }
+    }
+    lite_primitive->InferShape(input_tensors, output_tensors);
+    auto parameter = kernel::PopulateParameter(lite_primitive.get());
+    if (parameter == nullptr) {
+      MS_LOG(ERROR) << "PopulateParameter return nullptr, type: "
+                    << schema::EnumNamePrimitiveType((schema::PrimitiveType)(lite_primitive->Type()));
+      return nullptr;
+    }
+    auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, lite_primitive.get());
+    if (lite_kernel == nullptr) {
+      MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr";
+      FreeTensors(&input_tensors, &output_tensors);
+      return nullptr;
+    }
+    auto ret = lite_kernel->Run();
+    if (0 != ret) {
+      FreeTensors(&input_tensors, &output_tensors);
+      MS_LOG(ERROR) << "run kernel failed, name: " << lite_kernel->name();
+      return nullptr;
+    }
+    // replace cnode by new param
+    if (ReplaceCNode(func_graph, any_node, input_node, output_tensors, i) != lite::RET_OK) {
+      FreeTensors(&input_tensors, &output_tensors);
+      delete (lite_kernel);
+      MS_LOG(ERROR) << "constant_folding replace cnode failed";
+      return nullptr;
+    }
+    FreeTensors(&input_tensors, &output_tensors);
+    delete (lite_kernel);
   }
-  return any_node;
+  return changed ? any_node : nullptr;
 }
 }  // namespace mindspore::opt
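Two things worth noting about the rewritten Process (a reader's summary, not text from the diff): the early continue flattens the old nested if, input/output tensors and the scheduled kernel are now freed on every path, and the pass reports whether it did anything by returning any_node only when at least one input was folded. The multi-output rewrite delegated to ReplaceCNode has roughly this shape (sketch; names mirror the Split test above):

// Multi-output rewrite performed by ReplaceCNode:
//
//   before folding                       after folding
//   --------------                       -------------
//   folded = Split(consts...)            p0 = Parameter("<scope>_const_0")
//   t0 = TupleGetItem(folded, 0)         p1 = Parameter("<scope>_const_1")
//   t1 = TupleGetItem(folded, 1)         user0(p0)
//   user0(t0)                            user1(p1)
//   user1(t1)
//
// Each TupleGetItem user is replaced wholesale via manager->Replace(tuple_node,
// new_parameter); the single-output case still patches the consumer in place
// with any_node->set_input(replace_index, new_parameter).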