magicwindyyd / mindspore — fork of MindSpore / mindspore (in sync with the fork source)
Commit 2c63bdb1
Authored Aug 18, 2020 by mindspore-ci-bot; committed via Gitee, Aug 18, 2020

!4595 process Constant op in onnx converter && add two onnx models

Merge pull request !4595 from wangzhe/master

Parents: 19192e75, c9cb994e
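Background for the change: an ONNX graph stores constants in two places — in graph.initializer() (exported weights) and in standalone Constant nodes whose tensor payload sits in a "value" attribute. Before this commit the converter only consumed initializers; the new code also walks Constant nodes. A minimal illustrative traversal (not the converter's own code) showing both sources:

    #include <iostream>
    #include "onnx.pb.h"  // C++ classes generated from onnx.proto

    // Illustrative only: count both kinds of constants in an ONNX graph.
    void CountConstants(const onnx::GraphProto &graph) {
      // 1) Weights exported as initializers.
      std::cout << "initializers: " << graph.initializer_size() << std::endl;
      // 2) Constant operators; the tensor lives in their "value" attribute.
      int constant_nodes = 0;
      for (const auto &node : graph.node()) {
        if (node.op_type() != "Constant") continue;
        for (const auto &attr : node.attribute()) {
          if (attr.name() == "value") ++constant_nodes;  // attr.t() is an onnx::TensorProto
        }
      }
      std::cout << "Constant nodes: " << constant_nodes << std::endl;
    }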
Showing 6 changed files with 134 additions and 81 deletions (+134 −81)
mindspore/lite/test/models_onnx.cfg                                   +2   −0
mindspore/lite/test/run_benchmark_nets.sh                             +22  −0
mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc       +103 −75
mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h        +4   −1
mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc        +2   −2
mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.cc   +1   −3
mindspore/lite/test/models_onnx.cfg  (view file @ 2c63bdb1)

 mtk_detect-mbv2-shortcut-400-400-simplified.onnx
+mtk_emotions-d2012-75.8%.onnx
+mtk_face_features_v3.onnx
mindspore/lite/test/run_benchmark_nets.sh  (view file @ 2c63bdb1)

@@ -153,6 +153,28 @@ function Run_arm64() {
         fi
         #sleep 1
     done < ${models_caffe_config}

+    # Run onnx converted models:
+    while read line; do
+        model_name=${line}
+        if [[ $model_name == \#* ]]; then
+            continue
+        fi
+        echo ${model_name}
+        echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt
+        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelPath='${model_name}'.ms --inDataPath=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --calibDataPath=/data/local/tmp/input_output/output/'${model_name}'.ms.out --warmUpLoopCount=1 --loopCount=1'
+        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelPath='${model_name}'.ms --inDataPath=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --calibDataPath=/data/local/tmp/input_output/output/'${model_name}'.ms.out --warmUpLoopCount=1 --loopCount=1' >> adb_run_cmd.txt
+        adb -s ${device_id} shell < adb_run_cmd.txt
+        if [ $? = 0 ]; then
+            run_result='Run_arm64:'${model_name}' pass'
+            echo ${run_result} >> ${run_benchmark_result_file}
+        else
+            run_result='Run_arm64:'${model_name}' fail <<===========================this is the failed case'
+            echo ${run_result} >> ${run_benchmark_result_file}
+            return 1
+        fi
+        #sleep 1
+    done < ${models_onnx_config}
 }

 # Print start msg before run testcase
mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc  (view file @ 2c63bdb1)

@@ -76,45 +76,85 @@ STATUS OnnxModelParser::ReadOnnxModelFromBinary(const std::string &modelFile, go
 STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache) {
   MS_LOG(DEBUG) << "set onnx constant tensors";
   for (const auto &onnx_const_value : onnx_graph.initializer()) {
-    auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(onnx_const_value.data_type()));
-    if (data_type == kTypeUnknown) {
-      MS_LOG(ERROR) << "not support onnx data type "
-                    << static_cast<onnx::TensorProto_DataType>(onnx_const_value.data_type());
-      return RET_ERROR;
-    }
-    std::unique_ptr<schema::TensorT> tensor(new (std::nothrow) schema::TensorT);
-    if (tensor == nullptr) {
-      MS_LOG(ERROR) << "new tensor failed";
-      return RET_ERROR;
-    }
-    tensor->dataType = data_type;
-    tensor->format = schema::Format_NCHW;  // onnx use NCHW
-    std::copy(onnx_const_value.dims().begin(), onnx_const_value.dims().end(), std::back_inserter(tensor->dims));
-    tensor->nodeType = schema::NodeType_ValueNode;
-    if (CopyOnnxTensorData(onnx_const_value, tensor.get())) {
-      MS_LOG(ERROR) << "copy onnx data failed";
-      return RET_ERROR;
-    }
+    int index;
+    const auto status = AddTensorProto(onnx_const_value, onnx_const_value.name(), GRAPH_INPUT, tensor_cache, &index);
+    if (status != RET_OK) {
+      return status;
+    }
+    // TODO(wangzhe) why use GRAPH_INPUT other than CONST(GRAPH_INPUT will add index to graphInputs)
-    const auto index = tensor_cache->AddTensor(onnx_const_value.name(), tensor.release(), GRAPH_INPUT);
     MS_LOG(DEBUG) << "add const tensor: " << onnx_const_value.name() << ", index " << index;
   }
+  MS_LOG(DEBUG) << "process onnx Constant ops";
+  for (int i = 0; i < onnx_graph.node_size(); i++) {
+    const auto &node = onnx_graph.node(i);
+    if (node.op_type().compare("Constant") == 0) {
+      for (const auto &attr : node.attribute()) {
+        if (attr.name() == "sparse_value") {
+          MS_LOG(ERROR) << "sparse_value";
+        }
+        if (attr.name() == "value") {
+          const auto &t = attr.t();
+          int index;
+          const auto status = AddTensorProto(t, node.output(0), GRAPH_INPUT, tensor_cache, &index);
+          if (status != RET_OK) {
+            return status;
+          }
+          MS_LOG(DEBUG) << "add const tensor: " << t.name() << ", index " << index;
+        } else {
+          MS_LOG(ERROR) << "processing Constant op attr " << attr.name() << " not implemented";
+          return RET_INVALID_OP_ATTR;
+        }
+      }
+    }
+  }
   return RET_OK;
 }
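To make the new code path concrete, here is a hand-built one-node graph of the kind the second loop above now accepts; it uses the generated ONNX protobuf API and is purely illustrative (the converter reads such graphs from a model file rather than constructing them):

    #include "onnx.pb.h"  // C++ classes generated from onnx.proto

    // Illustrative only: a graph whose single Constant node would be picked up
    // by SetGraphConstTensor's new loop via AddTensorProto(t, node.output(0), ...).
    onnx::GraphProto MakeConstantGraph() {
      onnx::GraphProto graph;
      onnx::NodeProto *node = graph.add_node();
      node->set_op_type("Constant");
      node->add_output("const_out");
      onnx::AttributeProto *attr = node->add_attribute();
      attr->set_name("value");
      attr->set_type(onnx::AttributeProto_AttributeType_TENSOR);
      onnx::TensorProto *t = attr->mutable_t();
      t->set_name("const_out");
      t->set_data_type(onnx::TensorProto_DataType_FLOAT);
      t->add_dims(2);
      t->add_float_data(1.0f);  // payload read later by CopyOnnxTensorData
      t->add_float_data(2.0f);
      return graph;
    }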
-// TODO(wangzhe) seems AddTensorCache should be renamed to prepare tensor to add to tensor_cache
-STATUS OnnxModelParser::AddTensorCache(const onnx::ValueInfoProto &proto, schema::TensorT *tensor) {
+STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const TensorType &type,
+                                     TensorCache *tensor_cache, int *index) {
   auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.type().tensor_type().elem_type()));
   if (data_type == kTypeUnknown) {
-    MS_LOG(ERROR) << "not support onnx type "
+    MS_LOG(ERROR) << "not support onnx data type "
                   << static_cast<onnx::TensorProto_DataType>(proto.type().tensor_type().elem_type());
     return RET_ERROR;
   }
   std::unique_ptr<schema::TensorT> tensor(new schema::TensorT);
   if (tensor == nullptr) {
     MS_LOG(ERROR) << "new tensor failed";
     return RET_ERROR;
   }
   tensor->dataType = data_type;
   tensor->dims = GetDimsFromOnnxValue(proto);
   tensor->format = schema::Format_NCHW;
   tensor->nodeType = schema::NodeType_ValueNode;
+  // TODO(wangzhe) tensor->data and quantParams not set, should we need tensor_cache->AddTensor?
+  *index = tensor_cache->AddTensor(name, tensor.release(), type);
   return RET_OK;
 }

+STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const TensorType &type,
+                                       TensorCache *tensor_cache, int *index) {
+  auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.data_type()));
+  if (data_type == kTypeUnknown) {
+    MS_LOG(ERROR) << "not support onnx data type " << static_cast<onnx::TensorProto_DataType>(proto.data_type());
+    return RET_ERROR;
+  }
+  std::unique_ptr<schema::TensorT> tensor(new (std::nothrow) schema::TensorT);
+  if (tensor == nullptr) {
+    MS_LOG(ERROR) << "new tensor failed";
+    return RET_ERROR;
+  }
+  tensor->dataType = data_type;
+  std::copy(proto.dims().begin(), proto.dims().end(), std::back_inserter(tensor->dims));
+  tensor->format = schema::Format_NCHW;
+  tensor->nodeType = schema::NodeType_ValueNode;
+  if (CopyOnnxTensorData(proto, tensor.get())) {
+    MS_LOG(ERROR) << "copy onnx data failed";
+    return RET_ERROR;
+  }
+  if (data_type == kNumberTypeInt64) {
+    MS_LOG(ERROR) << "INT64" << proto.name();
+    tensor->dataType = kNumberTypeInt32;  // CopyOnnxTensorData will convert int64 to int32
+  }
+  *index = tensor_cache->AddTensor(name, tensor.release(), type);
+  return RET_OK;
+}
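AddTensorProto narrows int64 tensors to int32 (int64 constants in ONNX are typically shape, axes, or index data, which MindSpore Lite stores as int32). A self-contained sketch of the same range-checked narrowing policy, separate from the converter's own buffers:

    #include <cstdint>
    #include <limits>
    #include <vector>

    // Illustrative only: range-checked int64 -> int32 narrowing, mirroring what
    // CopyOnnxTensorData does for kNumberTypeInt64 tensors further below.
    bool NarrowToInt32(const std::vector<int64_t> &in, std::vector<int32_t> *out) {
      out->reserve(in.size());
      for (int64_t v : in) {
        if (v > std::numeric_limits<int32_t>::max() ||
            v < std::numeric_limits<int32_t>::min()) {
          return false;  // unrepresentable value; the converter returns RET_ERROR
        }
        out->push_back(static_cast<int32_t>(v));
      }
      return true;
    }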
@@ -123,15 +163,13 @@ STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph,
   for (const auto &input_value : onnx_graph.input()) {
     auto ret = tensor_cache->FindTensor(input_value.name());
     if (ret < 0) {
-      std::unique_ptr<schema::TensorT> tensor(new schema::TensorT);
-      // TODO(wangzhe) why there is an addtensorCache?
-      if (AddTensorCache(input_value, tensor.get())) {
-        return RET_ERROR;
+      int index;
+      const auto status = AddValueInfo(input_value, input_value.name(), GRAPH_INPUT, tensor_cache, &index);
+      if (status != RET_OK) {
+        return status;
       }
-      // TODO(wangzhe) why inputTensor is value and should be added into tensor_cache?
-      auto tensor_index = tensor_cache->AddTensor(input_value.name(), tensor.release(), GRAPH_INPUT);
-      graph->inputIndex.emplace_back(static_cast<uint32_t>(tensor_index));
-      MS_LOG(DEBUG) << "input_value name: " << input_value.name() << ", graph input index: " << tensor_index;
+      MS_LOG(ERROR) << "input_value name: " << input_value.name() << ", graph input index: " << index;
+      graph->inputIndex.emplace_back(static_cast<uint32_t>(index));
     }
   }
   return RET_OK;
@@ -140,14 +178,13 @@ STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph,
 STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph,
                                              TensorCache *tensor_cache) {
   for (const auto &output_value : onnx_graph.output()) {
-    std::unique_ptr<schema::TensorT> tensor(new schema::TensorT);
-    if (AddTensorCache(output_value, tensor.get())) {
-      return RET_ERROR;
+    int index;
+    const auto status = AddValueInfo(output_value, output_value.name(), OP_OUTPUT, tensor_cache, &index);
+    if (status != RET_OK) {
+      return status;
     }
-    // TODO(wangzhe) why we need AddTensor at OutputTensor
-    auto tensor_index = tensor_cache->AddTensor(output_value.name(), tensor.release(), OP_OUTPUT);
-    graph->outputIndex.emplace_back(tensor_index);
-    MS_LOG(DEBUG) << "output_value name: " << output_value.name() << ", graph output index: " << tensor_index;
+    graph->outputIndex.emplace_back(index);
+    MS_LOG(ERROR) << "output_value name: " << output_value.name() << ", graph output index: " << index;
   }
   return RET_OK;
 }
@@ -332,32 +369,11 @@ STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, co
 STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs, schema::CNodeT *dst_op,
                                         const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) {
-  schema::Format format = schema::Format_MAX;
-  for (const auto &onnx_node_attr : onnx_node.attribute()) {
-    if (onnx_node_attr.name() == "order") {
-      // do we need this code? onnx doc don't have order attr
-      MS_LOG(EXCEPTION) << "find order attr";
-      if (onnx_node_attr.s() == "NHWC") {
-        format = schema::Format_NHWC;
-      } else {
-        MS_LOG(ERROR) << "Unsupported format: " << onnx_node_attr.s();
-        return RET_ERROR;
-      }
-    }
-  }
   for (const auto &onnx_node_input : node_inputs) {
     auto index = tensor_cache->FindTensor(onnx_node_input);
-    // MS_LOG(ERROR) << onnx_node.name() << " input " << onnx_node_input << " index in tensor_cache " << index;
-    if (index < 0) {
-      // TODO(wangzhe) can this be ignored? because it's no use
-      /*
-      std::unique_ptr<schema::TensorT> tensor(new schema::TensorT);
-      index = tensor_cache->AddTensor(onnx_node_input, tensor.release(), OP_OUTPUT);
-      */
-      MS_LOG(EXCEPTION) << "input " << onnx_node_input << " of node " << onnx_node.name() << " can't be found";
-      // MS_LOG(INFO) << "new index: " << index;
-    }
-    if (format != schema::Format_MAX) {
-      // TODO(wangzhe) also this
-      auto inTensor = tensor_cache->GetCachedTensor().at(index);
-      inTensor->format = format;
+    if (index < 0) {
+      MS_LOG(ERROR) << "input " << onnx_node_input << " of node " << onnx_node.name() << " can't be found";
+      return RET_ERROR;
     }
+    MS_LOG(DEBUG) << "node: " << onnx_node_input << ", input index: " << index;
     dst_op->inputIndex.emplace_back(index);
@@ -369,19 +385,12 @@ STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs
                                          TensorCache *tensor_cache) {
   for (const auto &onnx_node_output : node_outputs) {
     auto index = tensor_cache->FindTensor(onnx_node_output);
-    if (index < 0) {
-      MS_LOG(INFO) << "output of node " << dst_op->name << " not in tensor_cache, creating";
-      MS_LOG(INFO) << "total " << node_outputs.size() << " outputs";
+    if (index < 0) {  // when index >= 0, it's graph's output
       std::unique_ptr<schema::TensorT> tensor(new schema::TensorT);
-      // GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.type().tensor_type().elem_type()));
-      // tensor->dataType = ;
-      // tensor->dims = tflite_tensor->shape;
       tensor->nodeType = schema::NodeType_Parameter;
       index = tensor_cache->AddTensor(onnx_node_output, tensor.release(), OP_OUTPUT);
     }
-    MS_LOG(DEBUG) << "node: " << onnx_node_output << ", input index: " << index;
+    MS_LOG(DEBUG) << "node: " << onnx_node_output << ", output index: " << index;
     dst_op->outputIndex.emplace_back(index);
   }
   return RET_OK;
@@ -390,8 +399,10 @@ STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs
 STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value, schema::TensorT *tensor) {
   size_t data_count = 1;
   std::for_each(tensor->dims.begin(), tensor->dims.end(), [&data_count](int dim) { data_count *= dim; });
+  MS_LOG(ERROR) << "const tensor dims " << tensor->dims.size();
   size_t data_size = 0;
   const void *tensor_data = nullptr;
+  int32_t *buffer = nullptr;
   switch (tensor->dataType) {
     case kNumberTypeFloat32:
       data_size = data_count * sizeof(float);
@@ -410,12 +421,23 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v
       }
       break;
     case kNumberTypeInt64:
-      data_size = data_count * sizeof(int64_t);
+      data_size = data_count * sizeof(int32_t);
+      buffer = new int32_t[data_count];
+      const int64_t *in_data;
       if (onnx_const_value.int64_data_size() == 0) {
-        tensor_data = onnx_const_value.raw_data().data();
+        in_data = reinterpret_cast<const int64_t *>(onnx_const_value.raw_data().data());
       } else {
-        tensor_data = onnx_const_value.int64_data().data();
+        in_data = onnx_const_value.int64_data().data();
      }
+      for (int i = 0; i < data_count; ++i) {
+        if (in_data[i] > static_cast<int64_t>(INT32_MAX) || in_data[i] < static_cast<int64_t>(INT32_MIN)) {
+          MS_LOG(ERROR) << "int64 data " << in_data[i] << "too big to fit into int32";
+          return RET_ERROR;
+        } else {
+          buffer[i] = static_cast<int>(in_data[i]);
+        }
+      }
+      tensor_data = reinterpret_cast<void *>(buffer);
       break;
     case kNumberTypeUInt8:
     case kNumberTypeInt8:
@@ -431,6 +453,9 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v
     MS_LOG(ERROR) << "memcpy_s failed";
     return RET_ERROR;
   }
+  if (kNumberTypeInt64 == tensor->dataType) {
+    free(buffer);
+  }
   return RET_OK;
 }
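Note that the int64 branch pairs new int32_t[data_count] with free(buffer) above. A sketch of an alternative (not the committed code) that gives the conversion buffer automatic lifetime, so no matching release call is needed:

    #include <cstdint>
    #include <vector>

    // Illustrative alternative: std::vector owns the converted data; callers
    // copy from buffer.data() (e.g. with memcpy_s) before it goes out of scope.
    std::vector<int32_t> ConvertInt64(const int64_t *in_data, size_t data_count) {
      std::vector<int32_t> buffer(data_count);
      for (size_t i = 0; i < data_count; ++i) {
        buffer[i] = static_cast<int32_t>(in_data[i]);  // range check (as above) omitted
      }
      return buffer;
    }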
@@ -491,6 +516,9 @@ MetaGraphT *OnnxModelParser::Parse(const std::string &modelFile, const std::stri
   }
   // init op node input/output tensor, and dst_op attr
   for (const auto &onnx_node : onnx_graph.node()) {
+    if (onnx_node.op_type() == "Constant") {
+      continue;
+    }
     if (onnx_node.op_type() == "Gemm") {
       ParseOnnxGemmNode(onnx_graph, onnx_node, dst_graph.get(), &tensor_cache);
       continue;
mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h  (view file @ 2c63bdb1)

@@ -48,7 +48,10 @@ class OnnxModelParser : public ModelParser {
   STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache);
   STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache);
   STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache);
-  STATUS AddTensorCache(const onnx::ValueInfoProto &proto, schema::TensorT *tensor);
+  STATUS AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const TensorType &type,
+                      TensorCache *tensor_cache, int *index);
+  STATUS AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const TensorType &type,
+                        TensorCache *tensor_cache, int *index);
   STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
                               schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache);
   void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
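The header splits the old AddTensorCache into two helpers because the two proto kinds carry different information: a ValueInfoProto (graph inputs/outputs) describes only element type and shape, while a TensorProto (initializers and Constant payloads) also carries data. A minimal illustration of where each helper reads its element type from:

    #include "onnx.pb.h"  // C++ classes generated from onnx.proto

    // Illustrative only: the field each new helper inspects.
    int ElemTypeOfValueInfo(const onnx::ValueInfoProto &v) {
      return v.type().tensor_type().elem_type();  // what AddValueInfo uses
    }
    int ElemTypeOfTensor(const onnx::TensorProto &t) {
      return t.data_type();  // what AddTensorProto uses
    }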
mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc  (view file @ 2c63bdb1)

@@ -23,6 +23,7 @@ STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
   MS_LOG(DEBUG) << "onnx PoolParser";
   std::unique_ptr<schema::PoolingT> attr(new schema::PoolingT());
+  attr->format = schema::Format_NCHW;
   const auto &pool_type = onnx_node.op_type();
   if (pool_type == "MaxPool") {
     attr->poolingMode = schema::PoolMode_MAX_POOLING;
@@ -37,7 +38,7 @@ STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
     attr->poolingMode = schema::PoolMode_MEAN_POOLING;
     attr->global = true;
   } else {
-    // MS_LOGE("Pooling param`s PoolingMode is not MAX either AVE. MindSpore support MAX and AVE only.");
+    MS_LOG(ERROR) << "Pooling param`s PoolingMode is not MAX either AVE. MindSpore support MAX and AVE only.";
     return RET_ERROR;
   }
@@ -92,4 +93,3 @@ OnnxNodeRegistrar g_onnxGlobalAveragePoolParser("GlobalAveragePool", new OnnxPoo
 OnnxNodeRegistrar g_onnxGlobalMaxPoolParser("GlobalMaxPool", new OnnxPoolParser());
 }  // namespace lite
 }  // namespace mindspore
-
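The g_onnx*Parser globals at the end of each parser file rely on static registration: constructing an OnnxNodeRegistrar at namespace scope records a parser in a registry keyed by ONNX op type. A hypothetical sketch of that pattern (the real OnnxNodeRegistrar is defined elsewhere in the converter and is not part of this diff):

    #include <map>
    #include <string>

    // Hypothetical sketch of static registration; names here are illustrative.
    class NodeParser {};  // stand-in for the converter's parser base class

    std::map<std::string, NodeParser *> &ParserRegistry() {
      static std::map<std::string, NodeParser *> registry;  // constructed on first use
      return registry;
    }

    struct Registrar {
      Registrar(const std::string &op_type, NodeParser *parser) {
        ParserRegistry()[op_type] = parser;  // runs during static initialization
      }
    };

    // Usage mirrors the diff: one global per supported op type.
    static Registrar g_maxPool("MaxPool", new NodeParser());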
mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.cc  (view file @ 2c63bdb1)

@@ -19,8 +19,7 @@
 namespace mindspore {
 namespace lite {
-STATUS OnnxUnSqueezeParser::Parse(const onnx::GraphProto &onnx_graph,
-                                  const onnx::NodeProto &onnx_node,
+STATUS OnnxUnSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
                                   schema::CNodeT *op) {
   MS_LOG(DEBUG) << "onnx UnSqueezeParser";
   std::unique_ptr<schema::UnsqueezeT> attr(new schema::UnsqueezeT());
@@ -43,4 +42,3 @@ STATUS OnnxUnSqueezeParser::Parse(const onnx::GraphProto &onnx_graph,
 OnnxNodeRegistrar g_onnxUnsqueezeParser("Unsqueeze", new OnnxUnSqueezeParser());
 }  // namespace lite
 }  // namespace mindspore
-