magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit 67f954a5

Authored Aug 28, 2020 by mindspore-ci-bot; committed via Gitee on Aug 28, 2020

!5455 fix onnx | mindir read protobuf bug in windows

Merge pull request !5455 from hangq/master

Parents: fa50dd02, 5a6c358d
Showing 10 changed files with 88 additions and 205 deletions (+88 -205)
mindspore/lite/test/CMakeLists.txt                                  +1   -0
mindspore/lite/tools/anf_importer/import_from_protobuf.cc           +3   -23
mindspore/lite/tools/common/protobuf_utils.cc                       +4   -7
mindspore/lite/tools/common/protobuf_utils.h                        +2   -5
mindspore/lite/tools/converter/CMakeLists.txt                       +1   -0
mindspore/lite/tools/converter/parser/caffe/CMakeLists.txt          +0   -1
mindspore/lite/tools/converter/parser/caffe/caffe_converter.cc      +0   -1
mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc   +14  -23
mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc     +28  -79
mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h      +35  -66
mindspore/lite/test/CMakeLists.txt

@@ -230,6 +230,7 @@ if(BUILD_CONVERTER)
         ${TEST_LITE_SRC}
         ${TEST_CASE_TFLITE_PARSERS_SRC}
         ${TOP_DIR}/mindspore/core/utils/flags.cc
+        ${LITE_DIR}/tools/common/protobuf_utils.cc
         ${LITE_DIR}/tools/converter/optimizer.cc
         ${LITE_DIR}/tools/converter/anf_transform.cc
         ${LITE_DIR}/tools/converter/graphdef_transform.cc
mindspore/lite/tools/anf_importer/import_from_protobuf.cc

@@ -27,7 +27,6 @@
 #include <vector>
 #include "src/ops/primitive_c.h"
 #include "frontend/operator/ops.h"
-#include "google/protobuf/io/zero_copy_stream_impl.h"
 #include "include/errorcode.h"
 #include "ir/anf.h"
 #include "ir/func_graph.h"

@@ -37,6 +36,7 @@
 #include "src/param_value_lite.h"
 #include "tools/converter/parser/onnx/onnx.pb.h"
 #include "utils/log_adapter.h"
+#include "tools/common/protobuf_utils.h"

 using string = std::string;
 using int32 = int32_t;

@@ -651,31 +651,11 @@ int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) {
 }

 onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) {
-  std::unique_ptr<char[]> onnx_file(new (std::nothrow) char[PATH_MAX]{0});
-#ifdef _WIN32
-  if (_fullpath(onnx_file.get(), model_path.c_str(), 1024) == nullptr) {
-    MS_LOG(ERROR) << "open file failed.";
-    return nullptr;
-  }
-#else
-  if (realpath(model_path.c_str(), onnx_file.get()) == nullptr) {
-    MS_LOG(ERROR) << "open file failed.";
-    return nullptr;
-  }
-#endif
-  int fd = open(onnx_file.get(), O_RDONLY);
-  google::protobuf::io::FileInputStream input(fd);
-  google::protobuf::io::CodedInputStream code_input(&input);
-  code_input.SetTotalBytesLimit(INT_MAX, 536870912);
   auto onnx_model = new onnx::ModelProto;
-  bool ret = onnx_model->ParseFromCodedStream(&code_input);
-  if (!ret) {
-    MS_LOG(ERROR) << "load onnx file failed";
+  if (ReadProtoFromBinaryFile((const char *)model_path.c_str(), onnx_model) != RET_OK) {
+    MS_LOG(ERROR) << "Read onnx model file failed, model path: " << model_path;
+    delete onnx_model;
     return nullptr;
   }
-  (void)close(fd);
-  MS_LOG(INFO) << "enter ReadProtoFromBinary success!" << std::endl;
   return onnx_model;
 }
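The removed block resolved the model path with realpath() on Linux and _fullpath() on Windows (note that _fullpath is told the buffer holds 1024 bytes even though it was allocated with PATH_MAX), then parsed the file through a FileInputStream built on a descriptor from open(); on Windows the CRT open() defaults to text mode unless O_BINARY is passed, which alone can corrupt a binary protobuf read. Routing every caller through the shared ReadProtoFromBinaryFile helper removes that per-caller fragility. For illustration only, a portable binary reader in the same spirit can be written with std::ifstream; this is a sketch under that assumption, not the body of the MindSpore helper:

#include <climits>
#include <fstream>
#include <string>
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"

// Sketch only: read a binary-serialized protobuf message portably.
// std::ifstream opened in binary mode behaves the same on Windows and Linux,
// so no realpath/_fullpath or POSIX file-descriptor handling is needed.
bool ReadBinaryProto(const std::string &path, google::protobuf::Message *message) {
  if (message == nullptr) {
    return false;
  }
  std::ifstream input(path, std::ios::in | std::ios::binary);
  if (!input) {
    return false;
  }
  google::protobuf::io::IstreamInputStream zero_copy_input(&input);
  google::protobuf::io::CodedInputStream coded_input(&zero_copy_input);
  // Match the limit the original code used; newer protobuf versions take a single argument.
  coded_input.SetTotalBytesLimit(INT_MAX, 536870912);
  return message->ParseFromCodedStream(&coded_input);
}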
mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.cc → mindspore/lite/tools/common/protobuf_utils.cc

@@ -14,7 +14,7 @@
  * limitations under the License.
  */

-#include "mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.h"
+#include "tools/common/protobuf_utils.h"
 #include <fstream>
 #include <string>
 #include "google/protobuf/io/zero_copy_stream_impl.h"

@@ -37,15 +37,14 @@ bool ReadProtoFromCodedInputStream(google::protobuf::io::CodedInputStream *coded
   return proto->ParseFromCodedStream(coded_stream);
 }

 STATUS ReadProtoFromText(const char *file, google::protobuf::Message *message) {
   if (file == nullptr || message == nullptr) {
     return RET_ERROR;
   }
   std::string realPath = RealPath(file);
   if (realPath.empty()) {
     MS_LOG(ERROR) << "Proto file path " << file << " is not valid";
     return RET_ERROR;
   }

@@ -67,8 +66,7 @@ STATUS ReadProtoFromText(const char *file,
   return RET_OK;
 }

 STATUS ReadProtoFromBinaryFile(const char *file, google::protobuf::Message *message) {
   if (file == nullptr || message == nullptr) {
     return RET_ERROR;
   }

@@ -100,4 +98,3 @@ STATUS ReadProtoFromBinaryFile(const char *file,
 }
 }  // namespace lite
 }  // namespace mindspore
mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.h → mindspore/lite/tools/common/protobuf_utils.h

@@ -29,13 +29,10 @@ namespace lite {
 bool ReadProtoFromCodedInputStream(google::protobuf::io::CodedInputStream *coded_stream,
                                    google::protobuf::Message *proto);

 STATUS ReadProtoFromText(const char *file, google::protobuf::Message *message);

 STATUS ReadProtoFromBinaryFile(const char *file, google::protobuf::Message *message);

 }  // namespace lite
 }  // namespace mindspore

 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_PARSE_UTILS_H_
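With the declarations in a shared header, a caller only needs the target message type and a path; the Windows/Linux differences stay inside the helper. A hedged usage sketch follows (the wrapper function is illustrative, not code from this commit):

#include <string>
#include "include/errorcode.h"
#include "tools/common/protobuf_utils.h"
#include "tools/converter/parser/onnx/onnx.pb.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace lite {
// Illustrative only: load an ONNX model through the shared binary reader,
// mirroring what ParseToFb does after this commit.
STATUS LoadOnnxModel(const std::string &model_path, onnx::ModelProto *onnx_model) {
  if (ReadProtoFromBinaryFile(model_path.c_str(), onnx_model) != RET_OK) {
    MS_LOG(ERROR) << "Read onnx model file failed, model path: " << model_path;
    return RET_ERROR;
  }
  return RET_OK;
}
}  // namespace lite
}  // namespace mindspore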
mindspore/lite/tools/converter/CMakeLists.txt

@@ -94,6 +94,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
         ${CMAKE_CURRENT_SOURCE_DIR}/../common/graph_util.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/../common/node_util.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/../common/tensor_util.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/../common/protobuf_utils.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/../common/storage.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/../../src/ir/primitive_t_value.cc
mindspore/lite/tools/converter/parser/caffe/CMakeLists.txt

@@ -15,7 +15,6 @@ add_library(caffe_parser_mid OBJECT
         ${CMAKE_CURRENT_SOURCE_DIR}/caffe_model_parser.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/caffe_node_parser.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/caffe_node_parser_registry.cc
-        ${CMAKE_CURRENT_SOURCE_DIR}/caffe_parse_utils.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/caffe_pooling_parser.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/caffe_power_parser.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/caffe_prelu_parser.cc
mindspore/lite/tools/converter/parser/caffe/caffe_converter.cc

@@ -15,7 +15,6 @@
  */

 #include "mindspore/lite/tools/converter/parser/caffe/caffe_converter.h"
-#include "mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.h"

 namespace mindspore {
 namespace lite {
mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc

@@ -14,14 +14,14 @@
  * limitations under the License.
  */

-#include "mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h"
+#include "tools/converter/parser/caffe/caffe_model_parser.h"
 #include <vector>
 #include <iostream>
 #include <utility>
-#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h"
-#include "mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.h"
-#include "mindspore/lite/tools/converter/parser/caffe/caffe_inspector.h"
+#include "tools/converter/parser/caffe/caffe_node_parser_registry.h"
+#include "tools/converter/parser/caffe/caffe_inspector.h"
 #include "tools/common/graph_util.h"
+#include "tools/common/protobuf_utils.h"

 namespace mindspore {
 namespace lite {

@@ -31,9 +31,8 @@ CaffeModelParser::~CaffeModelParser() {}
 const std::set<std::string> CaffeModelParser::skipedLayerType = {"Dropout"};

 schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile,
                                                 const std::string &weightFile,
                                                 const QuantType &quantType) {
   if (ValidateFileStr(modelFile, ".prototxt") != RET_OK) {
     MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.prototxt";
     return nullptr;

@@ -89,8 +88,7 @@ schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile,
   return metaGraph.release();
 }

 STATUS CaffeModelParser::SetOpInputIdx(const caffe::LayerParameter &layer,
                                        schema::CNodeT *op,
                                        TensorCache *tensorCache) {
   for (int i = 0; i < layer.bottom_size(); i++) {
     int index = tensorCache->FindTensor(layer.bottom(i));

@@ -104,8 +102,7 @@ STATUS CaffeModelParser::SetOpInputIdx(const caffe::LayerParameter &layer,
   return RET_OK;
 }

 STATUS CaffeModelParser::SetOpOutputIdx(const caffe::LayerParameter &layer,
                                         schema::CNodeT *op,
                                         TensorCache *tensorCache) {
   for (int i = 0; i < layer.top_size(); i++) {
     std::unique_ptr<schema::TensorT> msTensor = std::make_unique<schema::TensorT>();

@@ -114,8 +111,7 @@ STATUS CaffeModelParser::SetOpOutputIdx(const caffe::LayerParameter &layer,
   return RET_OK;
 }

 STATUS CaffeModelParser::SetWeightTensor(const std::vector<schema::TensorT *> &weightVec,
                                          schema::CNodeT *op,
                                          TensorCache *tensorCache) {
   for (auto iter : weightVec) {
     op->inputIndex.emplace_back(tensorCache->AddTensor("Weight", iter, CONST));

@@ -123,8 +119,7 @@ STATUS CaffeModelParser::SetWeightTensor(const std::vector<schema::TensorT *> &w
   return RET_OK;
 }

 STATUS CaffeModelParser::SetAllTensors(const TensorCache &tensorCache,
                                        schema::MetaGraphT *subGraphDef) {
   std::vector<schema::TensorT *> tensors = tensorCache.GetCachedTensor();
   for (auto iter : tensors) {
     std::unique_ptr<schema::TensorT> temp(iter);

@@ -133,8 +128,7 @@ STATUS CaffeModelParser::SetAllTensors(const TensorCache &tensorCache,
   return RET_OK;
 }

 STATUS CaffeModelParser::SetGraphTensorIndex(const caffe::NetParameter &proto,
                                              TensorCache *tensorCache,
                                              schema::MetaGraphT *subGraphDef) {
   CaffeInspector caffeInspector;
   caffeInspector.InspectModel(proto);

@@ -160,10 +154,8 @@ STATUS CaffeModelParser::SetGraphTensorIndex(const caffe::NetParameter &proto,
   return RET_OK;
 }

 STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto,
                                     const caffe::NetParameter &weight,
                                     TensorCache *tensorCache, schema::MetaGraphT *subGraphDef) {
   for (int i = 0; i < proto.layer_size(); i++) {
     auto layer = proto.layer(i);

@@ -235,8 +227,7 @@ STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto,
   return RET_OK;
 }

 STATUS CaffeModelParser::GetModelInput(const caffe::NetParameter &proto,
                                        TensorCache *tensorCache) {
   for (int i = 0; i < proto.input_size(); i++) {
     if (proto.input_dim_size() <= 0) {
       continue;
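The Caffe parser now follows the same pattern: the .prototxt network definition goes through ReadProtoFromText and the binary weight file through ReadProtoFromBinaryFile, both declared in tools/common/protobuf_utils.h. A compact sketch of that flow; the wrapper function and the generated caffe.pb.h include path are assumptions for illustration:

#include <string>
#include "caffe.pb.h"  // generated Caffe proto header; include path assumed for this sketch
#include "include/errorcode.h"
#include "tools/common/protobuf_utils.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace lite {
// Illustrative only: read the network definition (text) and the weights (binary)
// through the shared helpers, as CaffeModelParser::ParseToFb does after this commit.
STATUS LoadCaffeModel(const std::string &model_file, const std::string &weight_file,
                      caffe::NetParameter *proto, caffe::NetParameter *weight) {
  if (ReadProtoFromText(model_file.c_str(), proto) != RET_OK) {
    MS_LOG(ERROR) << "Read prototxt file failed, file path: " << model_file;
    return RET_ERROR;
  }
  if (ReadProtoFromBinaryFile(weight_file.c_str(), weight) != RET_OK) {
    MS_LOG(ERROR) << "Read caffemodel file failed, file path: " << weight_file;
    return RET_ERROR;
  }
  return RET_OK;
}
}  // namespace lite
}  // namespace mindspore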
mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc  (mode 100755 → 100644)

@@ -21,6 +21,7 @@
 #include <utility>
 #include "tools/common/graph_util.h"
 #include "src/common/utils.h"
+#include "tools/common/protobuf_utils.h"

 namespace mindspore {
 namespace lite {

@@ -54,36 +55,7 @@ std::vector<int32_t> OnnxModelParser::GetDimsFromOnnxValue(const onnx::ValueInfo
   return dims;
 }

-STATUS OnnxModelParser::ReadOnnxModelFromBinary(const std::string &modelFile,
-                                                google::protobuf::Message *onnx_model) {
-  std::unique_ptr<char> onnx_file(new (std::nothrow) char[PATH_MAX]{0});
-#ifdef _WIN32
-  if (_fullpath(onnx_file.get(), modelFile.c_str(), 1024) == nullptr) {
-    MS_LOG(ERROR) << "get realpath " << modelFile << " fail";
-    return RET_ERROR;
-  }
-#else
-  if (realpath(modelFile.c_str(), onnx_file.get()) == nullptr) {
-    MS_LOG(ERROR) << "get realpath " << modelFile << " fail";
-    return RET_ERROR;
-  }
-#endif
-  int fd = open(onnx_file.get(), O_RDONLY);
-  google::protobuf::io::FileInputStream input(fd);
-  google::protobuf::io::CodedInputStream code_input(&input);
-  code_input.SetTotalBytesLimit(INT_MAX, 536870912);
-  bool ret = onnx_model->ParseFromCodedStream(&code_input);
-  if (!ret) {
-    MS_LOG(ERROR) << "load onnx file failed";
-    return RET_ERROR;
-  }
-  (void)close(fd);
-  onnx_file.release();
-  return RET_OK;
-}
-
 STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache) {
   MS_LOG(DEBUG) << "set onnx constant tensors";
   for (const auto &onnx_const_value : onnx_graph.initializer()) {
     int index;

@@ -119,11 +91,8 @@ STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph,
   return RET_OK;
 }

 STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name,
                                      const TensorType &type, TensorCache *tensor_cache, int *index) {
   auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.type().tensor_type().elem_type()));
   if (data_type == kTypeUnknown) {
     MS_LOG(ERROR) << "not support onnx data type "

@@ -143,11 +112,8 @@ STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto,
   return RET_OK;
 }

 STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, const std::string &name,
                                        const TensorType &type, TensorCache *tensor_cache, int *index) {
   auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.data_type()));
   if (data_type == kTypeUnknown) {
     MS_LOG(ERROR) << "not support onnx data type " << static_cast<onnx::TensorProto_DataType>(proto.data_type());

@@ -174,8 +140,7 @@ STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto,
   return RET_OK;
 }

 STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph,
                                             TensorCache *tensor_cache) {
   for (const auto &input_value : onnx_graph.input()) {
     auto ret = tensor_cache->FindTensor(input_value.name());

@@ -192,8 +157,7 @@ STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph,
   return RET_OK;
 }

 STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph,
                                              TensorCache *tensor_cache) {
   for (const auto &output_value : onnx_graph.output()) {
     int index;

@@ -207,10 +171,8 @@ STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph,
   return RET_OK;
 }

 void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
                                         schema::MetaGraphT *graph, TensorCache *tensor_cache) {
   std::unique_ptr<schema::CNodeT> dst_op_1 = std::make_unique<schema::CNodeT>();
   dst_op_1->name = "Gemm_MatMul_" + onnx_node.output(0);
   ParseOnnxNodeAttr(onnx_graph, onnx_node, "MatMul", dst_op_1.get());

@@ -231,8 +193,7 @@ void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph,
   graph->nodes.emplace_back(std::move(dst_op_2));
 }

 STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) {
   // convert GivenTensorFill node to a weight/bias tensor
   auto ret = tensor_cache->FindTensor(onnx_node.output(0));
   if (ret < 0) {

@@ -284,10 +245,8 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
   return RET_OK;
 }

 STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
                                              schema::CNodeT *dst_op, schema::TensorT *dst_tensor,
                                              TensorCache *tensor_cache) {
   // change op_type() to name(), that is unique
   dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0);

@@ -319,11 +278,8 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
   return RET_OK;
 }

 void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
                                        schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache) {
   MS_ASSERT(dst_op != nullptr);
   MS_ASSERT(tensor_cache != nullptr);
   std::vector<string> quant_node_name;

@@ -380,10 +336,8 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph,
   }
 }

 STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
                                           const string &onnx_op_type, schema::CNodeT *dst_op) {
   auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_op_type);
   if (node_parser == nullptr) {
     MS_LOG(EXCEPTION) << "not find " << onnx_op_type << ", node parser is nullptr";

@@ -392,10 +346,8 @@ STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph,
   return node_parser->Parse(onnx_graph, onnx_node, dst_op);
 }

 STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs, schema::CNodeT *dst_op,
                                         const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) {
   for (const auto &onnx_node_input : node_inputs) {
     auto index = tensor_cache->FindTensor(onnx_node_input);
     if (index < 0) {

@@ -408,8 +360,7 @@ STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs,
   return RET_OK;
 }

 STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs, schema::CNodeT *dst_op,
                                          TensorCache *tensor_cache) {
   for (const auto &onnx_node_output : node_outputs) {
     auto index = tensor_cache->FindTensor(onnx_node_output);

@@ -424,8 +375,7 @@ STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs
   return RET_OK;
 }

 STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value,
                                            schema::TensorT *tensor) {
   size_t data_count = 1;
   std::for_each(tensor->dims.begin(), tensor->dims.end(), [&data_count](int dim) { data_count *= dim; });
   size_t data_size = 0;

@@ -484,8 +434,7 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v
   return RET_OK;
 }

 STATUS OnnxModelParser::SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *graphDef) {
   std::vector<schema::TensorT *> tensors = tensor_cache.GetCachedTensor();
   for (auto iter : tensors) {
     std::unique_ptr<schema::TensorT> temp(iter);

@@ -507,17 +456,16 @@ void OnnxModelParser::FindGraphInputAndConst(const onnx::GraphProto &onnx_graph)
   }
 }

 schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile,
                                                const std::string &weightFile,
                                                const QuantType &quantType) {
   if (ValidateFileStr(modelFile, ".onnx") != RET_OK) {
     MS_LOG(ERROR) << "Input illegal: modelFile must be *.onnx";
     return nullptr;
   }
-  auto dst_graph = std::make_unique<schema::MetaGraphT>();
   onnx::ModelProto onnx_model;
-  if (ReadOnnxModelFromBinary(modelFile, &onnx_model) != RET_OK) {
-    MS_LOG(ERROR) << "read onnx model fail";
+  if (ReadProtoFromBinaryFile((const char *)modelFile.c_str(), &onnx_model) != RET_OK) {
+    MS_LOG(ERROR) << "Read onnx model file failed, model path: " << modelFile;
     return nullptr;
   }
   const onnx::GraphProto &onnx_graph = onnx_model.graph();

@@ -531,6 +479,7 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile,
     MS_LOG(ERROR) << "SetGraphConstTensor failed";
     return nullptr;
   }
+  auto dst_graph = std::make_unique<schema::MetaGraphT>();
   // init onnx model graph input tensor
   if (SetGraphInputTensor(onnx_graph, dst_graph.get(), &tensor_cache)) {
     MS_LOG(ERROR) << "SetGraphInputTensor failed";
mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h

@@ -41,78 +41,47 @@ class OnnxModelParser : public ModelParser {
   virtual ~OnnxModelParser();

   schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile,
                                 const QuantType &quantType = QuantType_QUANT_NONE) override;

  private:
   TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type);
   std::vector<int32_t> GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value);
-  STATUS ReadOnnxModelFromBinary(const std::string &modelFile,
-                                 google::protobuf::Message *model_proto);
   STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache);
   STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph,
                              TensorCache *tensor_cache);
   STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph,
                               TensorCache *tensor_cache);
   STATUS AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const TensorType &type,
                       TensorCache *tensor_cache, int *index);
   STATUS AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const TensorType &type,
                         TensorCache *tensor_cache, int *index);
   STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
                               schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache);
   void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
                          schema::MetaGraphT *graph, TensorCache *tensor_cache);
   STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache);
   STATUS ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
                            const string &onnx_op_type, schema::CNodeT *dst_op);
   void SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *dst_op,
                         schema::TensorT *dst_tensor, TensorCache *tensor_cache);
   STATUS SetOpInputIndex(const std::vector<string> &node_inputs, schema::CNodeT *dst_op,
                          const onnx::NodeProto &onnx_node, TensorCache *tensor_cache);
   STATUS SetOpOutputIndex(const std::vector<string> &node_outputs, schema::CNodeT *dst_op,
                           TensorCache *tensor_cache);
   STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_init_value, schema::TensorT *tensor);
   STATUS SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *graphDef);
   void FindGraphInputAndConst(const onnx::GraphProto &onnx_graph);