提交 67f954a5 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5455 fix onnx | mindir read protobuf bug in windows

Merge pull request !5455 from hangq/master
...@@ -230,6 +230,7 @@ if(BUILD_CONVERTER) ...@@ -230,6 +230,7 @@ if(BUILD_CONVERTER)
${TEST_LITE_SRC} ${TEST_LITE_SRC}
${TEST_CASE_TFLITE_PARSERS_SRC} ${TEST_CASE_TFLITE_PARSERS_SRC}
${TOP_DIR}/mindspore/core/utils/flags.cc ${TOP_DIR}/mindspore/core/utils/flags.cc
${LITE_DIR}/tools/common/protobuf_utils.cc
${LITE_DIR}/tools/converter/optimizer.cc ${LITE_DIR}/tools/converter/optimizer.cc
${LITE_DIR}/tools/converter/anf_transform.cc ${LITE_DIR}/tools/converter/anf_transform.cc
${LITE_DIR}/tools/converter/graphdef_transform.cc ${LITE_DIR}/tools/converter/graphdef_transform.cc
......
...@@ -27,7 +27,6 @@ ...@@ -27,7 +27,6 @@
#include <vector> #include <vector>
#include "src/ops/primitive_c.h" #include "src/ops/primitive_c.h"
#include "frontend/operator/ops.h" #include "frontend/operator/ops.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "include/errorcode.h" #include "include/errorcode.h"
#include "ir/anf.h" #include "ir/anf.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
...@@ -37,6 +36,7 @@ ...@@ -37,6 +36,7 @@
#include "src/param_value_lite.h" #include "src/param_value_lite.h"
#include "tools/converter/parser/onnx/onnx.pb.h" #include "tools/converter/parser/onnx/onnx.pb.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "tools/common/protobuf_utils.h"
using string = std::string; using string = std::string;
using int32 = int32_t; using int32 = int32_t;
...@@ -651,31 +651,11 @@ int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) { ...@@ -651,31 +651,11 @@ int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) {
} }
onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) { onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) {
std::unique_ptr<char[]> onnx_file(new (std::nothrow) char[PATH_MAX]{0});
#ifdef _WIN32
if (_fullpath(onnx_file.get(), model_path.c_str(), 1024) == nullptr) {
MS_LOG(ERROR) << "open file failed.";
return nullptr;
}
#else
if (realpath(model_path.c_str(), onnx_file.get()) == nullptr) {
MS_LOG(ERROR) << "open file failed.";
return nullptr;
}
#endif
int fd = open(onnx_file.get(), O_RDONLY);
google::protobuf::io::FileInputStream input(fd);
google::protobuf::io::CodedInputStream code_input(&input);
code_input.SetTotalBytesLimit(INT_MAX, 536870912);
auto onnx_model = new onnx::ModelProto; auto onnx_model = new onnx::ModelProto;
bool ret = onnx_model->ParseFromCodedStream(&code_input); if (ReadProtoFromBinaryFile((const char *)model_path.c_str(), onnx_model) != RET_OK) {
if (!ret) { MS_LOG(ERROR) << "Read onnx model file failed, model path: " << model_path;
MS_LOG(ERROR) << "load onnx file failed";
delete onnx_model;
return nullptr; return nullptr;
} }
(void)close(fd);
MS_LOG(INFO) << "enter ReadProtoFromBinary success!" << std::endl;
return onnx_model; return onnx_model;
} }
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
#include "mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.h" #include "tools/common/protobuf_utils.h"
#include <fstream> #include <fstream>
#include <string> #include <string>
#include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/io/zero_copy_stream_impl.h"
...@@ -37,15 +37,14 @@ bool ReadProtoFromCodedInputStream(google::protobuf::io::CodedInputStream *coded ...@@ -37,15 +37,14 @@ bool ReadProtoFromCodedInputStream(google::protobuf::io::CodedInputStream *coded
return proto->ParseFromCodedStream(coded_stream); return proto->ParseFromCodedStream(coded_stream);
} }
STATUS ReadProtoFromText(const char *file, STATUS ReadProtoFromText(const char *file, google::protobuf::Message *message) {
google::protobuf::Message *message) {
if (file == nullptr || message == nullptr) { if (file == nullptr || message == nullptr) {
return RET_ERROR; return RET_ERROR;
} }
std::string realPath = RealPath(file); std::string realPath = RealPath(file);
if (realPath.empty()) { if (realPath.empty()) {
MS_LOG(ERROR) << "Proto file path " << file <<" is not valid"; MS_LOG(ERROR) << "Proto file path " << file << " is not valid";
return RET_ERROR; return RET_ERROR;
} }
...@@ -67,8 +66,7 @@ STATUS ReadProtoFromText(const char *file, ...@@ -67,8 +66,7 @@ STATUS ReadProtoFromText(const char *file,
return RET_OK; return RET_OK;
} }
STATUS ReadProtoFromBinaryFile(const char *file, STATUS ReadProtoFromBinaryFile(const char *file, google::protobuf::Message *message) {
google::protobuf::Message *message) {
if (file == nullptr || message == nullptr) { if (file == nullptr || message == nullptr) {
return RET_ERROR; return RET_ERROR;
} }
...@@ -100,4 +98,3 @@ STATUS ReadProtoFromBinaryFile(const char *file, ...@@ -100,4 +98,3 @@ STATUS ReadProtoFromBinaryFile(const char *file,
} }
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
...@@ -29,13 +29,10 @@ namespace lite { ...@@ -29,13 +29,10 @@ namespace lite {
bool ReadProtoFromCodedInputStream(google::protobuf::io::CodedInputStream *coded_stream, bool ReadProtoFromCodedInputStream(google::protobuf::io::CodedInputStream *coded_stream,
google::protobuf::Message *proto); google::protobuf::Message *proto);
STATUS ReadProtoFromText(const char *file, STATUS ReadProtoFromText(const char *file, google::protobuf::Message *message);
google::protobuf::Message *message);
STATUS ReadProtoFromBinaryFile(const char *file, STATUS ReadProtoFromBinaryFile(const char *file, google::protobuf::Message *message);
google::protobuf::Message *message);
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_PARSE_UTILS_H_ #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_PARSE_UTILS_H_
...@@ -94,6 +94,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ...@@ -94,6 +94,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/../common/graph_util.cc ${CMAKE_CURRENT_SOURCE_DIR}/../common/graph_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/node_util.cc ${CMAKE_CURRENT_SOURCE_DIR}/../common/node_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/tensor_util.cc ${CMAKE_CURRENT_SOURCE_DIR}/../common/tensor_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/protobuf_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc ${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/storage.cc ${CMAKE_CURRENT_SOURCE_DIR}/../common/storage.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/ir/primitive_t_value.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../src/ir/primitive_t_value.cc
......
...@@ -15,7 +15,6 @@ add_library(caffe_parser_mid OBJECT ...@@ -15,7 +15,6 @@ add_library(caffe_parser_mid OBJECT
${CMAKE_CURRENT_SOURCE_DIR}/caffe_model_parser.cc ${CMAKE_CURRENT_SOURCE_DIR}/caffe_model_parser.cc
${CMAKE_CURRENT_SOURCE_DIR}/caffe_node_parser.cc ${CMAKE_CURRENT_SOURCE_DIR}/caffe_node_parser.cc
${CMAKE_CURRENT_SOURCE_DIR}/caffe_node_parser_registry.cc ${CMAKE_CURRENT_SOURCE_DIR}/caffe_node_parser_registry.cc
${CMAKE_CURRENT_SOURCE_DIR}/caffe_parse_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/caffe_pooling_parser.cc ${CMAKE_CURRENT_SOURCE_DIR}/caffe_pooling_parser.cc
${CMAKE_CURRENT_SOURCE_DIR}/caffe_power_parser.cc ${CMAKE_CURRENT_SOURCE_DIR}/caffe_power_parser.cc
${CMAKE_CURRENT_SOURCE_DIR}/caffe_prelu_parser.cc ${CMAKE_CURRENT_SOURCE_DIR}/caffe_prelu_parser.cc
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
*/ */
#include "mindspore/lite/tools/converter/parser/caffe/caffe_converter.h" #include "mindspore/lite/tools/converter/parser/caffe/caffe_converter.h"
#include "mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
......
...@@ -14,14 +14,14 @@ ...@@ -14,14 +14,14 @@
* limitations under the License. * limitations under the License.
*/ */
#include "mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h" #include "tools/converter/parser/caffe/caffe_model_parser.h"
#include <vector> #include <vector>
#include <iostream> #include <iostream>
#include <utility> #include <utility>
#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" #include "tools/converter/parser/caffe/caffe_node_parser_registry.h"
#include "mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.h" #include "tools/converter/parser/caffe/caffe_inspector.h"
#include "mindspore/lite/tools/converter/parser/caffe/caffe_inspector.h"
#include "tools/common/graph_util.h" #include "tools/common/graph_util.h"
#include "tools/common/protobuf_utils.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
...@@ -31,9 +31,8 @@ CaffeModelParser::~CaffeModelParser() {} ...@@ -31,9 +31,8 @@ CaffeModelParser::~CaffeModelParser() {}
const std::set<std::string> CaffeModelParser::skipedLayerType = {"Dropout"}; const std::set<std::string> CaffeModelParser::skipedLayerType = {"Dropout"};
schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile,
const std::string &weightFile, const QuantType &quantType) {
const QuantType &quantType) {
if (ValidateFileStr(modelFile, ".prototxt") != RET_OK) { if (ValidateFileStr(modelFile, ".prototxt") != RET_OK) {
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.prototxt"; MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.prototxt";
return nullptr; return nullptr;
...@@ -89,8 +88,7 @@ schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, ...@@ -89,8 +88,7 @@ schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile,
return metaGraph.release(); return metaGraph.release();
} }
STATUS CaffeModelParser::SetOpInputIdx(const caffe::LayerParameter &layer, STATUS CaffeModelParser::SetOpInputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op,
schema::CNodeT *op,
TensorCache *tensorCache) { TensorCache *tensorCache) {
for (int i = 0; i < layer.bottom_size(); i++) { for (int i = 0; i < layer.bottom_size(); i++) {
int index = tensorCache->FindTensor(layer.bottom(i)); int index = tensorCache->FindTensor(layer.bottom(i));
...@@ -104,8 +102,7 @@ STATUS CaffeModelParser::SetOpInputIdx(const caffe::LayerParameter &layer, ...@@ -104,8 +102,7 @@ STATUS CaffeModelParser::SetOpInputIdx(const caffe::LayerParameter &layer,
return RET_OK; return RET_OK;
} }
STATUS CaffeModelParser::SetOpOutputIdx(const caffe::LayerParameter &layer, STATUS CaffeModelParser::SetOpOutputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op,
schema::CNodeT *op,
TensorCache *tensorCache) { TensorCache *tensorCache) {
for (int i = 0; i < layer.top_size(); i++) { for (int i = 0; i < layer.top_size(); i++) {
std::unique_ptr<schema::TensorT> msTensor = std::make_unique<schema::TensorT>(); std::unique_ptr<schema::TensorT> msTensor = std::make_unique<schema::TensorT>();
...@@ -114,8 +111,7 @@ STATUS CaffeModelParser::SetOpOutputIdx(const caffe::LayerParameter &layer, ...@@ -114,8 +111,7 @@ STATUS CaffeModelParser::SetOpOutputIdx(const caffe::LayerParameter &layer,
return RET_OK; return RET_OK;
} }
STATUS CaffeModelParser::SetWeightTensor(const std::vector<schema::TensorT *> &weightVec, STATUS CaffeModelParser::SetWeightTensor(const std::vector<schema::TensorT *> &weightVec, schema::CNodeT *op,
schema::CNodeT *op,
TensorCache *tensorCache) { TensorCache *tensorCache) {
for (auto iter : weightVec) { for (auto iter : weightVec) {
op->inputIndex.emplace_back(tensorCache->AddTensor("Weight", iter, CONST)); op->inputIndex.emplace_back(tensorCache->AddTensor("Weight", iter, CONST));
...@@ -123,8 +119,7 @@ STATUS CaffeModelParser::SetWeightTensor(const std::vector<schema::TensorT *> &w ...@@ -123,8 +119,7 @@ STATUS CaffeModelParser::SetWeightTensor(const std::vector<schema::TensorT *> &w
return RET_OK; return RET_OK;
} }
STATUS CaffeModelParser::SetAllTensors(const TensorCache &tensorCache, STATUS CaffeModelParser::SetAllTensors(const TensorCache &tensorCache, schema::MetaGraphT *subGraphDef) {
schema::MetaGraphT *subGraphDef) {
std::vector<schema::TensorT *> tensors = tensorCache.GetCachedTensor(); std::vector<schema::TensorT *> tensors = tensorCache.GetCachedTensor();
for (auto iter : tensors) { for (auto iter : tensors) {
std::unique_ptr<schema::TensorT> temp(iter); std::unique_ptr<schema::TensorT> temp(iter);
...@@ -133,8 +128,7 @@ STATUS CaffeModelParser::SetAllTensors(const TensorCache &tensorCache, ...@@ -133,8 +128,7 @@ STATUS CaffeModelParser::SetAllTensors(const TensorCache &tensorCache,
return RET_OK; return RET_OK;
} }
STATUS CaffeModelParser::SetGraphTensorIndex(const caffe::NetParameter &proto, STATUS CaffeModelParser::SetGraphTensorIndex(const caffe::NetParameter &proto, TensorCache *tensorCache,
TensorCache *tensorCache,
schema::MetaGraphT *subGraphDef) { schema::MetaGraphT *subGraphDef) {
CaffeInspector caffeInspector; CaffeInspector caffeInspector;
caffeInspector.InspectModel(proto); caffeInspector.InspectModel(proto);
...@@ -160,10 +154,8 @@ STATUS CaffeModelParser::SetGraphTensorIndex(const caffe::NetParameter &proto, ...@@ -160,10 +154,8 @@ STATUS CaffeModelParser::SetGraphTensorIndex(const caffe::NetParameter &proto,
return RET_OK; return RET_OK;
} }
STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, const caffe::NetParameter &weight,
const caffe::NetParameter &weight, TensorCache *tensorCache, schema::MetaGraphT *subGraphDef) {
TensorCache *tensorCache,
schema::MetaGraphT *subGraphDef) {
for (int i = 0; i < proto.layer_size(); i++) { for (int i = 0; i < proto.layer_size(); i++) {
auto layer = proto.layer(i); auto layer = proto.layer(i);
...@@ -235,8 +227,7 @@ STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, ...@@ -235,8 +227,7 @@ STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto,
return RET_OK; return RET_OK;
} }
STATUS CaffeModelParser::GetModelInput(const caffe::NetParameter &proto, STATUS CaffeModelParser::GetModelInput(const caffe::NetParameter &proto, TensorCache *tensorCache) {
TensorCache *tensorCache) {
for (int i = 0; i < proto.input_size(); i++) { for (int i = 0; i < proto.input_size(); i++) {
if (proto.input_dim_size() <= 0) { if (proto.input_dim_size() <= 0) {
continue; continue;
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <utility> #include <utility>
#include "tools/common/graph_util.h" #include "tools/common/graph_util.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "tools/common/protobuf_utils.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
...@@ -54,36 +55,7 @@ std::vector<int32_t> OnnxModelParser::GetDimsFromOnnxValue(const onnx::ValueInfo ...@@ -54,36 +55,7 @@ std::vector<int32_t> OnnxModelParser::GetDimsFromOnnxValue(const onnx::ValueInfo
return dims; return dims;
} }
STATUS OnnxModelParser::ReadOnnxModelFromBinary(const std::string &modelFile, STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache) {
google::protobuf::Message *onnx_model) {
std::unique_ptr<char> onnx_file(new (std::nothrow) char[PATH_MAX]{0});
#ifdef _WIN32
if (_fullpath(onnx_file.get(), modelFile.c_str(), 1024) == nullptr) {
MS_LOG(ERROR) << "get realpath " << modelFile << " fail";
return RET_ERROR;
}
#else
if (realpath(modelFile.c_str(), onnx_file.get()) == nullptr) {
MS_LOG(ERROR) << "get realpath " << modelFile << " fail";
return RET_ERROR;
}
#endif
int fd = open(onnx_file.get(), O_RDONLY);
google::protobuf::io::FileInputStream input(fd);
google::protobuf::io::CodedInputStream code_input(&input);
code_input.SetTotalBytesLimit(INT_MAX, 536870912);
bool ret = onnx_model->ParseFromCodedStream(&code_input);
if (!ret) {
MS_LOG(ERROR) << "load onnx file failed";
return RET_ERROR;
}
(void)close(fd);
onnx_file.release();
return RET_OK;
}
STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph,
TensorCache *tensor_cache) {
MS_LOG(DEBUG) << "set onnx constant tensors"; MS_LOG(DEBUG) << "set onnx constant tensors";
for (const auto &onnx_const_value : onnx_graph.initializer()) { for (const auto &onnx_const_value : onnx_graph.initializer()) {
int index; int index;
...@@ -119,11 +91,8 @@ STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph, ...@@ -119,11 +91,8 @@ STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph,
return RET_OK; return RET_OK;
} }
STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const TensorType &type,
const std::string &name, TensorCache *tensor_cache, int *index) {
const TensorType &type,
TensorCache *tensor_cache,
int *index) {
auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.type().tensor_type().elem_type())); auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.type().tensor_type().elem_type()));
if (data_type == kTypeUnknown) { if (data_type == kTypeUnknown) {
MS_LOG(ERROR) << "not support onnx data type " MS_LOG(ERROR) << "not support onnx data type "
...@@ -143,11 +112,8 @@ STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, ...@@ -143,11 +112,8 @@ STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto,
return RET_OK; return RET_OK;
} }
STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const TensorType &type,
const std::string &name, TensorCache *tensor_cache, int *index) {
const TensorType &type,
TensorCache *tensor_cache,
int *index) {
auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.data_type())); auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.data_type()));
if (data_type == kTypeUnknown) { if (data_type == kTypeUnknown) {
MS_LOG(ERROR) << "not support onnx data type " << static_cast<onnx::TensorProto_DataType>(proto.data_type()); MS_LOG(ERROR) << "not support onnx data type " << static_cast<onnx::TensorProto_DataType>(proto.data_type());
...@@ -174,8 +140,7 @@ STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, ...@@ -174,8 +140,7 @@ STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto,
return RET_OK; return RET_OK;
} }
STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph,
schema::MetaGraphT *graph,
TensorCache *tensor_cache) { TensorCache *tensor_cache) {
for (const auto &input_value : onnx_graph.input()) { for (const auto &input_value : onnx_graph.input()) {
auto ret = tensor_cache->FindTensor(input_value.name()); auto ret = tensor_cache->FindTensor(input_value.name());
...@@ -192,8 +157,7 @@ STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, ...@@ -192,8 +157,7 @@ STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph,
return RET_OK; return RET_OK;
} }
STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph,
schema::MetaGraphT *graph,
TensorCache *tensor_cache) { TensorCache *tensor_cache) {
for (const auto &output_value : onnx_graph.output()) { for (const auto &output_value : onnx_graph.output()) {
int index; int index;
...@@ -207,10 +171,8 @@ STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, ...@@ -207,10 +171,8 @@ STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph,
return RET_OK; return RET_OK;
} }
void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
const onnx::NodeProto &onnx_node, schema::MetaGraphT *graph, TensorCache *tensor_cache) {
schema::MetaGraphT *graph,
TensorCache *tensor_cache) {
std::unique_ptr<schema::CNodeT> dst_op_1 = std::make_unique<schema::CNodeT>(); std::unique_ptr<schema::CNodeT> dst_op_1 = std::make_unique<schema::CNodeT>();
dst_op_1->name = "Gemm_MatMul_" + onnx_node.output(0); dst_op_1->name = "Gemm_MatMul_" + onnx_node.output(0);
ParseOnnxNodeAttr(onnx_graph, onnx_node, "MatMul", dst_op_1.get()); ParseOnnxNodeAttr(onnx_graph, onnx_node, "MatMul", dst_op_1.get());
...@@ -231,8 +193,7 @@ void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, ...@@ -231,8 +193,7 @@ void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph,
graph->nodes.emplace_back(std::move(dst_op_2)); graph->nodes.emplace_back(std::move(dst_op_2));
} }
STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) {
TensorCache *tensor_cache) {
// convert GivenTensorFill node to a weight/bias tensor // convert GivenTensorFill node to a weight/bias tensor
auto ret = tensor_cache->FindTensor(onnx_node.output(0)); auto ret = tensor_cache->FindTensor(onnx_node.output(0));
if (ret < 0) { if (ret < 0) {
...@@ -284,10 +245,8 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, ...@@ -284,10 +245,8 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
return RET_OK; return RET_OK;
} }
STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
const onnx::NodeProto &onnx_node, schema::CNodeT *dst_op, schema::TensorT *dst_tensor,
schema::CNodeT *dst_op,
schema::TensorT *dst_tensor,
TensorCache *tensor_cache) { TensorCache *tensor_cache) {
// change op_type() to name(), that is unique // change op_type() to name(), that is unique
dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0); dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0);
...@@ -319,11 +278,8 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, ...@@ -319,11 +278,8 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
return RET_OK; return RET_OK;
} }
void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
const onnx::NodeProto &onnx_node, schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache) {
schema::CNodeT *dst_op,
schema::TensorT *dst_tensor,
TensorCache *tensor_cache) {
MS_ASSERT(dst_op != nullptr); MS_ASSERT(dst_op != nullptr);
MS_ASSERT(tensor_cache != nullptr); MS_ASSERT(tensor_cache != nullptr);
std::vector<string> quant_node_name; std::vector<string> quant_node_name;
...@@ -380,10 +336,8 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, ...@@ -380,10 +336,8 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph,
} }
} }
STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
const onnx::NodeProto &onnx_node, const string &onnx_op_type, schema::CNodeT *dst_op) {
const string &onnx_op_type,
schema::CNodeT *dst_op) {
auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_op_type); auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_op_type);
if (node_parser == nullptr) { if (node_parser == nullptr) {
MS_LOG(EXCEPTION) << "not find " << onnx_op_type << ", node parser is nullptr"; MS_LOG(EXCEPTION) << "not find " << onnx_op_type << ", node parser is nullptr";
...@@ -392,10 +346,8 @@ STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, ...@@ -392,10 +346,8 @@ STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph,
return node_parser->Parse(onnx_graph, onnx_node, dst_op); return node_parser->Parse(onnx_graph, onnx_node, dst_op);
} }
STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs, STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs, schema::CNodeT *dst_op,
schema::CNodeT *dst_op, const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) {
const onnx::NodeProto &onnx_node,
TensorCache *tensor_cache) {
for (const auto &onnx_node_input : node_inputs) { for (const auto &onnx_node_input : node_inputs) {
auto index = tensor_cache->FindTensor(onnx_node_input); auto index = tensor_cache->FindTensor(onnx_node_input);
if (index < 0) { if (index < 0) {
...@@ -408,8 +360,7 @@ STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs, ...@@ -408,8 +360,7 @@ STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs,
return RET_OK; return RET_OK;
} }
STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs, STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs, schema::CNodeT *dst_op,
schema::CNodeT *dst_op,
TensorCache *tensor_cache) { TensorCache *tensor_cache) {
for (const auto &onnx_node_output : node_outputs) { for (const auto &onnx_node_output : node_outputs) {
auto index = tensor_cache->FindTensor(onnx_node_output); auto index = tensor_cache->FindTensor(onnx_node_output);
...@@ -424,8 +375,7 @@ STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs ...@@ -424,8 +375,7 @@ STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs
return RET_OK; return RET_OK;
} }
STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value, STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value, schema::TensorT *tensor) {
schema::TensorT *tensor) {
size_t data_count = 1; size_t data_count = 1;
std::for_each(tensor->dims.begin(), tensor->dims.end(), [&data_count](int dim) { data_count *= dim; }); std::for_each(tensor->dims.begin(), tensor->dims.end(), [&data_count](int dim) { data_count *= dim; });
size_t data_size = 0; size_t data_size = 0;
...@@ -484,8 +434,7 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v ...@@ -484,8 +434,7 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v
return RET_OK; return RET_OK;
} }
STATUS OnnxModelParser::SetAllTensors(const TensorCache &tensor_cache, STATUS OnnxModelParser::SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *graphDef) {
schema::MetaGraphT *graphDef) {
std::vector<schema::TensorT *> tensors = tensor_cache.GetCachedTensor(); std::vector<schema::TensorT *> tensors = tensor_cache.GetCachedTensor();
for (auto iter : tensors) { for (auto iter : tensors) {
std::unique_ptr<schema::TensorT> temp(iter); std::unique_ptr<schema::TensorT> temp(iter);
...@@ -507,17 +456,16 @@ void OnnxModelParser::FindGraphInputAndConst(const onnx::GraphProto &onnx_graph) ...@@ -507,17 +456,16 @@ void OnnxModelParser::FindGraphInputAndConst(const onnx::GraphProto &onnx_graph)
} }
} }
schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile,
const std::string &weightFile, const QuantType &quantType) {
const QuantType &quantType) {
if (ValidateFileStr(modelFile, ".onnx") != RET_OK) { if (ValidateFileStr(modelFile, ".onnx") != RET_OK) {
MS_LOG(ERROR) << "Input illegal: modelFile must be *.onnx"; MS_LOG(ERROR) << "Input illegal: modelFile must be *.onnx";
return nullptr; return nullptr;
} }
auto dst_graph = std::make_unique<schema::MetaGraphT>();
onnx::ModelProto onnx_model; onnx::ModelProto onnx_model;
if (ReadOnnxModelFromBinary(modelFile, &onnx_model) != RET_OK) { if (ReadProtoFromBinaryFile((const char *)modelFile.c_str(), &onnx_model) != RET_OK) {
MS_LOG(ERROR) << "read onnx model fail"; MS_LOG(ERROR) << "Read onnx model file failed, model path: " << modelFile;
return nullptr; return nullptr;
} }
const onnx::GraphProto &onnx_graph = onnx_model.graph(); const onnx::GraphProto &onnx_graph = onnx_model.graph();
...@@ -531,6 +479,7 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, ...@@ -531,6 +479,7 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile,
MS_LOG(ERROR) << "SetGraphConstTensor failed"; MS_LOG(ERROR) << "SetGraphConstTensor failed";
return nullptr; return nullptr;
} }
auto dst_graph = std::make_unique<schema::MetaGraphT>();
// init onnx model graph input tensor // init onnx model graph input tensor
if (SetGraphInputTensor(onnx_graph, dst_graph.get(), &tensor_cache)) { if (SetGraphInputTensor(onnx_graph, dst_graph.get(), &tensor_cache)) {
MS_LOG(ERROR) << "SetGraphInputTensor failed"; MS_LOG(ERROR) << "SetGraphInputTensor failed";
......
...@@ -41,78 +41,47 @@ class OnnxModelParser : public ModelParser { ...@@ -41,78 +41,47 @@ class OnnxModelParser : public ModelParser {
virtual ~OnnxModelParser(); virtual ~OnnxModelParser();
schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile, schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) override; const QuantType &quantType = QuantType_QUANT_NONE) override;
private: private:
TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type); TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type);
std::vector<int32_t> GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value); std::vector<int32_t> GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value);
STATUS ReadOnnxModelFromBinary(const std::string &modelFile, STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache);
google::protobuf::Message *model_proto);
STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache);
STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph,
TensorCache *tensor_cache); STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache);
STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, STATUS AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const TensorType &type,
schema::MetaGraphT *graph, TensorCache *tensor_cache, int *index);
TensorCache *tensor_cache);
STATUS AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const TensorType &type,
STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache, int *index);
schema::MetaGraphT *graph,
TensorCache *tensor_cache); STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache);
STATUS AddValueInfo(const onnx::ValueInfoProto &proto,
const std::string &name, void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
const TensorType &type, schema::MetaGraphT *graph, TensorCache *tensor_cache);
TensorCache *tensor_cache,
int *index); STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache);
STATUS AddTensorProto(const onnx::TensorProto &proto, STATUS ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
const std::string &name, const string &onnx_op_type, schema::CNodeT *dst_op);
const TensorType &type,
TensorCache *tensor_cache, void SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *dst_op,
int *index); schema::TensorT *dst_tensor, TensorCache *tensor_cache);
STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, STATUS SetOpInputIndex(const std::vector<string> &node_inputs, schema::CNodeT *dst_op,
const onnx::NodeProto &onnx_node, const onnx::NodeProto &onnx_node, TensorCache *tensor_cache);
schema::CNodeT *dst_op,
schema::TensorT *dst_tensor, STATUS SetOpOutputIndex(const std::vector<string> &node_outputs, schema::CNodeT *dst_op, TensorCache *tensor_cache);
TensorCache *tensor_cache);
STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_init_value, schema::TensorT *tensor);
void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node, STATUS SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *graphDef);
schema::MetaGraphT *graph,
TensorCache *tensor_cache);
STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
TensorCache *tensor_cache);
STATUS ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
const string &onnx_op_type,
schema::CNodeT *dst_op);
void SetOpQuantParams(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *dst_op,
schema::TensorT *dst_tensor,
TensorCache *tensor_cache);
STATUS SetOpInputIndex(const std::vector<string> &node_inputs,
schema::CNodeT *dst_op,
const onnx::NodeProto &onnx_node,
TensorCache *tensor_cache);
STATUS SetOpOutputIndex(const std::vector<string> &node_outputs,
schema::CNodeT *dst_op,
TensorCache *tensor_cache);
STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_init_value,
schema::TensorT *tensor);
STATUS SetAllTensors(const TensorCache &tensor_cache,
schema::MetaGraphT *graphDef);
void FindGraphInputAndConst(const onnx::GraphProto &onnx_graph); void FindGraphInputAndConst(const onnx::GraphProto &onnx_graph);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册