diff --git a/mindspore/ccsrc/transform/onnx/ir_exporter.cc b/mindspore/ccsrc/transform/onnx/ir_exporter.cc index 57c07c6889051dfa558a13c8cf7ac165f1b87a78..a6b7cbe38bd60eff271704b355ac3b111fa76c7d 100644 --- a/mindspore/ccsrc/transform/onnx/ir_exporter.cc +++ b/mindspore/ccsrc/transform/onnx/ir_exporter.cc @@ -90,14 +90,17 @@ class IrExportBuilder { void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, onnx::TensorProto *const tensor_proto); void SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto); void SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto); - void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::NodeProto *const node_proto, - std::string suffix = "0"); + void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::AttributeProto *const attr_proto, + std::string *const seq_string); void SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); void SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); void SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); void SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); - void SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto); - void SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto); + void SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto, const std::string &value_name); + void SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto, + std::string *const seq_string); + void SetSeqElemToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto, + std::string *const seq_string); onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id); onnx::TensorProto_DataType GetOnnxDataBitsIntType(int bits); @@ -105,8 +108,10 @@ class IrExportBuilder { std::string GetNodeName(const AnfNodePtr &node); std::string GetUniqueNodeName(const AnfNodePtr &node); std::string GetOpTypeName(const AnfNodePtr &node); - size_t AllocateIndex() { return ++node_index_; } - void ResetIndex() { node_index_ = 0; } + size_t GetNodeIndex() { return ++node_index_; } + void ResetNodeIndex() { node_index_ = 0; } + size_t GetTupleIndex() { return ++shape_index_; } + void ResetTupleIndex() { shape_index_ = 0; } private: onnx::ModelProto model_; @@ -114,6 +119,7 @@ class IrExportBuilder { std::list todo_; std::map node_index_map_; size_t node_index_{0}; + size_t shape_index_{0}; }; using IrExporterPtr = std::shared_ptr; @@ -146,7 +152,7 @@ void IrExportBuilder::BuildModelInfo() { void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) { onnx::GraphProto *graph_proto = model_.mutable_graph(); graph_proto->set_name(func_graph->ToString()); - ResetIndex(); + ResetNodeIndex(); todo_.clear(); todo_.push_back(func_graph); while (!todo_.empty()) { @@ -177,7 +183,7 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, onnx::Grap input_proto->set_name(param_name); SetValueInfoProto(param, input_proto); if (!param->has_default()) { - MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default"; + MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default."; continue; } @@ -232,13 +238,20 @@ void IrExportBuilder::SetValueInfoProto(const TypePtr &type, const BaseShapePtr auto elem_type = tensor->element(); const auto &dims = shape->cast()->shape(); type_proto->mutable_tensor_type()->set_elem_type(GetOnnxDataType(elem_type->type_id())); - for (const auto &dim : dims) { - MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim; - type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); + if (dims.size() == 0) { + MS_LOG(DEBUG) << "SetValueInfoProto set default dim 1."; + type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + } else { + for (const auto &dim : dims) { + MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim; + type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); + } } } else if (type->isa()) { auto tup_shape = shape->cast(); - type_proto->set_denotation(std::to_string(tup_shape->shape().size())); + type_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size())); + } else if (type->isa() || type->isa()) { + type_proto->set_denotation(type->type_name()); } else { MS_LOG(EXCEPTION) << "Value type: " << type->type_name() << " is not supported!"; } @@ -248,9 +261,10 @@ void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, onnx::Att if (value == nullptr || attr_proto == nullptr) { MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; } - attr_proto->set_ref_attr_name("tensor"); - attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); - onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); + attr_proto->set_ref_attr_name("tensor:value0"); + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS); + onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); + tensor_proto->set_name("value0"); auto data = value->cast(); tensor_proto->set_raw_data(data->data_c(), static_cast(data->data().nbytes())); auto dtype = data->data_type(); @@ -284,6 +298,7 @@ void IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, onnx::Ten void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { std::vector nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); + bool is_only_return = true; for (const AnfNodePtr &node : nodes) { if (!node->isa()) { MS_LOG(DEBUG) << "Node: '" << node->ToString() << "' is not cnode"; @@ -291,9 +306,13 @@ void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProt } auto cnode = node->cast(); if (cnode == func_graph->get_return()) { + if (is_only_return) { + MS_LOG(EXCEPTION) << "Only has return node, can't convert to binary model!"; + } BuildOutput(cnode, graph_proto); } else { BuildCNode(cnode, graph_proto); + is_only_return = false; } } } @@ -303,24 +322,11 @@ void IrExportBuilder::BuildOutput(const CNodePtr &node, onnx::GraphProto *const MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2."; } AnfNodePtr arg = node->input(1); - // Using make_tuple to set multi-output - if (IsPrimitiveCNode(arg, prim::kPrimMakeTuple)) { - auto tuple_node = arg->cast(); - for (size_t i = 1; i < tuple_node->size(); i++) { - auto input_node = arg->cast()->input(i); - onnx::ValueInfoProto *output_proto = graph_proto->add_output(); - auto output_name = GetUniqueNodeName(tuple_node->input(i)); - output_proto->set_name(output_name); - last_node_->add_output(output_name); - SetValueInfoProto(tuple_node->input(i), output_proto); - } - } else { - onnx::ValueInfoProto *output_proto = graph_proto->add_output(); - std::string output_name = GetUniqueNodeName(node); - output_proto->set_name(output_name); - last_node_->add_output(output_name); - SetValueInfoProto(arg, output_proto); - } + onnx::ValueInfoProto *output_proto = graph_proto->add_output(); + std::string output_name = GetUniqueNodeName(node); + output_proto->set_name(output_name); + last_node_->set_output(0, output_name); + SetValueInfoProto(arg, output_proto); } std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) { @@ -343,45 +349,44 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) { } void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, - onnx::NodeProto *const node_proto, std::string suffix) { - onnx::AttributeProto *attr_proto = node_proto->add_attribute(); - attr_proto->set_ref_attr_name("shape"); - if (suffix.compare("0") != 0) { - attr_proto->set_name("shape" + suffix); - } else { - attr_proto->set_name("shape"); - } - onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); - SetTensorProto(type, shape, tensor_proto); -} - -void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto) { - // Get shape of cnode - // 1. prim ArgMaxWithValue need to get shape from tuple element - // 2. some cnode doesn't has shape, such as LayerNorm - // 3. other cnodes have shape - if (node->IsApply(prim::kPrimArgMaxWithValue) || node->IsApply(prim::kPrimLayerNorm)) { - auto type = node->Type(); - auto shape = node->Shape(); - if (!type->isa()) { - MS_LOG(EXCEPTION) << "Output data of ArgMaxWithValue cnode must be tuple: " << type->type_name(); - } + onnx::AttributeProto *const attr_proto, std::string *const seq_string) { + if (type->isa() && seq_string != nullptr) { + *seq_string += "Tuple["; auto elements = type->cast()->elements(); auto tuple_shape = shape->cast()->shape(); for (size_t i = 0; i < elements.size(); i++) { - SetShapeToNodeProto(elements[i], tuple_shape[i], node_proto, std::to_string(i)); + SetShapeToNodeProto(elements[i], tuple_shape[i], attr_proto, seq_string); } + *seq_string += "],"; + } else if (type->isa() && shape->isa() && seq_string != nullptr) { + string shape_name = "shape" + std::to_string(GetTupleIndex()); + *seq_string += shape_name + ","; + onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); + tensor_proto->set_name(shape_name); + SetTensorProto(type, shape, tensor_proto); + } else if ((type->isa() || type->isa()) && seq_string != nullptr) { + *seq_string += type->type_name() + ","; } else { - auto type = node->Type(); - auto shape = node->Shape(); - if (!type->isa() || !shape->isa()) { - MS_LOG(DEBUG) << "Cnode has no shape: " << node->ToString(); - return; - } - SetShapeToNodeProto(type, shape, node_proto); + MS_LOG(EXCEPTION) << "Type of cnode need to be supported: " << type->type_name(); } } +void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto) { + // Get shape of cnode + // 1. need to get shape from tuple element + // 2. save shape in TensorProto + // 3. save tuple string in ref_attr_name + MS_EXCEPTION_IF_NULL(node); + auto type = node->Type(); + auto shape = node->Shape(); + ResetTupleIndex(); + std::string seq_string = "shape:"; + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + SetShapeToNodeProto(type, shape, attr_proto, &seq_string); + attr_proto->set_ref_attr_name(seq_string); + MS_LOG(DEBUG) << "CNode shape: " << seq_string; +} + void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto) { auto inputs_size = node->size(); if (inputs_size < 1) { @@ -443,15 +448,19 @@ std::string IrExportBuilder::GetUniqueNodeName(const AnfNodePtr &node) { std::string node_name = ""; if (node->isa()) { node_name = GetNodeName(node); - } else if (node->isa() || node->isa()) { + } else if (node->isa()) { auto iter = node_index_map_.find(node); if (iter != node_index_map_.end()) { node_name = GetNodeName(node) + ":" + std::to_string(iter->second); } else { - auto node_idx = AllocateIndex(); + auto node_idx = GetNodeIndex(); node_index_map_[node] = node_idx; node_name = GetNodeName(node) + ":" + std::to_string(node_idx); } + } else if (node->isa()) { + auto node_idx = GetNodeIndex(); + node_index_map_[node] = node_idx; + node_name = GetNodeName(node) + ":" + std::to_string(node_idx); } else { MS_LOG(EXCEPTION) << "Can not support type of node:" << node->ToString(); } @@ -485,17 +494,21 @@ void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, onnx::Attri if (value == nullptr || attr_proto == nullptr) { MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; } - attr_proto->set_ref_attr_name("type"); - attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); - onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS); + onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); if (value->isa()) { + attr_proto->set_ref_attr_name("type:value0"); + tensor_proto->set_name("value0"); auto int_value = value->cast(); tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits())); } else if (value->isa()) { + attr_proto->set_ref_attr_name("type:value0"); + tensor_proto->set_name("value0"); auto float_value = value->cast(); tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits())); } else if (value->isa()) { - tensor_proto->set_name("tensor"); + attr_proto->set_ref_attr_name("type:tensor0"); + tensor_proto->set_name("tensor0"); auto elem_type = value->cast()->element(); if (elem_type->isa()) { auto int_value = elem_type->cast(); @@ -519,10 +532,18 @@ void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, onnx::Attr SetScalarToAttributeProto(value, attr_proto); } else if (value->isa() || value->isa()) { SetTypeToAttributeProto(value, attr_proto); - } else if (value->isa()) { - SetSequenceToAttributeProto(value->cast(), attr_proto); + } else if (value->isa() || value->isa()) { + ResetTupleIndex(); + std::string seq_string = "scalar:"; + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS); + SetSequenceToAttributeProto(value->cast(), attr_proto, &seq_string); + attr_proto->set_ref_attr_name(seq_string); + MS_LOG(DEBUG) << "Attr string: " << seq_string; } else if (value->isa()) { SetTensorToAttributeProto(value, attr_proto); + } else if (value->isa()) { + attr_proto->set_ref_attr_name("none"); + MS_LOG(DEBUG) << "Attr string: " << value->type_name(); } else { MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name(); } @@ -532,16 +553,18 @@ void IrExportBuilder::SetScalarToAttributeProto(const ValuePtr &value, onnx::Att if (value == nullptr || attr_proto == nullptr) { MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; } - attr_proto->set_ref_attr_name("scalar"); - attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); - onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); - SetScalarToProto(value, tensor_proto); + attr_proto->set_ref_attr_name("scalar:value0"); + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS); + onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); + SetScalarToProto(value, tensor_proto, "value0"); } -void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto) { +void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto, + const std::string &value_name) { if (value == nullptr || tensor_proto == nullptr) { MS_LOG(EXCEPTION) << "ValuePtr or TensorProto is null!"; } + tensor_proto->set_name(value_name); if (value->isa()) { tensor_proto->set_data_type(onnx::TensorProto_DataType_STRING); tensor_proto->add_string_data(GetValue(value)); @@ -560,44 +583,74 @@ void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto } else if (value->isa()) { tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64); tensor_proto->add_int64_data(value->cast()->value()); - } else if (value->isa()) { + } else if (value->isa()) { + tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT8); + tensor_proto->add_int32_data(value->cast()->value()); + } else if (value->isa()) { + tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT16); + tensor_proto->add_int32_data(value->cast()->value()); + } else if (value->isa()) { + tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT32); + tensor_proto->add_uint64_data(value->cast()->value()); + } else if (value->isa()) { + tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT64); + tensor_proto->add_uint64_data(value->cast()->value()); + } else if (value->isa()) { tensor_proto->set_data_type(onnx::TensorProto_DataType_FLOAT); tensor_proto->add_float_data(GetValue(value)); + } else if (value->isa()) { + tensor_proto->set_data_type(onnx::TensorProto_DataType_DOUBLE); + tensor_proto->add_double_data(GetValue(value)); } else { MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name(); } } -void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, - onnx::AttributeProto *const attr_proto) { +void IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto, + std::string *const seq_string) { + string value_name = "value" + std::to_string(GetTupleIndex()); + if (seq_string != nullptr) { + *seq_string += value_name + ","; + } + onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); + SetScalarToProto(value, tensor_proto, value_name); +} + +void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto, + std::string *const seq_string) { if (value == nullptr || attr_proto == nullptr) { MS_LOG(EXCEPTION) << "ValueSequeuePtr or AttributeProto is null!"; } - attr_proto->set_ref_attr_name("scalar"); - attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); - onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); - if (value->isa()) { + if (value->isa() && seq_string != nullptr) { + *seq_string += "Tuple["; const ValueTuplePtr &tuple_value = value->cast(); if (tuple_value->value().size() == 0) { MS_LOG(DEBUG) << "SetSequenceToAttributeProto tuple size is 0"; return; } - auto type_id = tuple_value->value()[0]->type()->type_id(); - tensor_proto->set_data_type(GetOnnxDataType(type_id)); for (const auto &item : tuple_value->value()) { - SetScalarToProto(item, tensor_proto); + if (item->isa()) { + SetSequenceToAttributeProto(item->cast(), attr_proto, seq_string); + } else { + SetSeqElemToAttributeProto(item, attr_proto, seq_string); + } } - } else if (value->isa()) { + *seq_string += "],"; + } else if (value->isa() && seq_string != nullptr) { + *seq_string += "List["; const ValueListPtr &list_value = value->cast(); if (list_value->value().size() == 0) { - MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0"; + MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0."; return; } - auto type_id = list_value->value()[0]->type()->type_id(); - tensor_proto->set_data_type(GetOnnxDataType(type_id)); for (const auto &item : list_value->value()) { - SetScalarToProto(item, tensor_proto); + if (item->isa()) { + SetSequenceToAttributeProto(item->cast(), attr_proto, seq_string); + } else { + SetSeqElemToAttributeProto(item, attr_proto, seq_string); + } } + *seq_string += "],"; } } diff --git a/mindspore/ccsrc/utils/load_onnx/anf_converter.cc b/mindspore/ccsrc/utils/load_onnx/anf_converter.cc index d5626d14c14025bf97a7e856f7c0f8b25815b3c1..cce4e813d9576dc66baaa2fb07fecf9d8ecd7934 100644 --- a/mindspore/ccsrc/utils/load_onnx/anf_converter.cc +++ b/mindspore/ccsrc/utils/load_onnx/anf_converter.cc @@ -57,7 +57,7 @@ int AnfConverter::ValidateFileStr(const std::string &modelFile, std::string file bool AnfConverter::ReadOnnxFromBinary(const std::string &modelFile, google::protobuf::Message *onnx_model) { std::unique_ptr onnx_file(new (std::nothrow) char[PATH_MAX]{0}); - int fd = open(onnx_file.get(), O_RDONLY); + int fd = open(modelFile.c_str(), O_RDONLY); if (fd < 0) { MS_LOG(EXCEPTION) << "failed to open file"; } diff --git a/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc b/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc index fe16f88be65d3dc3e6a2ec12ca88832beeaa9c1f..6e38d72edec751e6e8bd5ac3163b26233e0f61e6 100644 --- a/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc +++ b/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc @@ -18,8 +18,12 @@ #include #include #include +#include #include #include +#include +#include +#include "google/protobuf/io/zero_copy_stream_impl.h" #include "ir/tensor.h" #include "ir/param_info.h" #include "frontend/operator/ops.h" @@ -55,6 +59,97 @@ static std::unordered_map kDefaultValueSwitchMap{ {onnx::TensorProto_DataType_STRING, kObjectTypeString}, }; +template +std::shared_ptr ParserAttr(const std::string &str, const std::unordered_map &kv) { + std::stack rules; + std::stack

value; + int count = 0; + for (size_t i = 0; i < str.length(); i++) { + if (str[i] == '[') { + rules.push("["); + } else if (str[i] == ']') { + // rules + std::vector

vec; + while (rules.top() != "[") { + rules.pop(); + vec.push_back(value.top()); + value.pop(); + } + // pop "[" + rules.pop(); + // make tuple for names + std::string res = "dummy"; + // make tuple for values + reverse(vec.begin(), vec.end()); + auto vt = std::make_shared(vec); + if (rules.empty() && value.empty()) { + return vt; + } + rules.push(res); + value.push(vt); + } else if (str[i] == ',') { + continue; + } else { + count++; + if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') { + auto value_name = str.substr(i - count + 1, count); + value.push(kv.at(value_name)); + rules.push(value_name); + count = 0; + } + } + } + return {}; +} + +std::shared_ptr ParserScalarAttrValue(const std::string &attr_name, + const std::unordered_map &kv) { + std::string str = attr_name; + auto replace = [&](const string &orgStr, const string &newStr) { + std::string::size_type pos(0); + while ((pos = str.find(orgStr)) != std::string::npos) { + str.replace(pos, orgStr.length(), newStr); + } + return str; + }; + // remove "scalar:" + str = replace("scalar:", ""); + // remove "Tuple" + str = replace("Tuple", ""); + // remove "List" + str = replace("List", ""); + auto result = ParserAttr(str, kv); + if (!result) { + return {}; + } + return result; +} + +std::shared_ptr ParserAttrShape( + const std::string &attr_name, const std::unordered_map &kv) { + std::string str = attr_name; + auto replace = [&](const string &orgStr, const string &newStr) { + std::string::size_type pos(0); + while ((pos = str.find(orgStr)) != std::string::npos) { + str.replace(pos, orgStr.length(), newStr); + } + return str; + }; + // remove "scalar:" + str = replace("shape:", ""); + // remove "Tuple" + str = replace("Tuple", ""); + // remove "List" + str = replace("List", ""); + + auto result = ParserAttr(str, kv); + if (!result) { + return {}; + } + return result; +} + +#if 0 #define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ void ParseAttrInScalar_##type##_##valuetype(const PrimitivePtr &prim, const std::string &attr_name, \ const onnx::TensorProto &attr_tensor) { \ @@ -67,9 +162,16 @@ static std::unordered_map kDefaultValueSwitchMap{ if (attr_value_vec.size() == 1) { \ prim->AddAttr(attr_name, attr_value_vec[0]); \ } else { \ - prim->AddAttr(attr_name, std::make_shared(attr_value_vec)); \ + ParserScalarAttrValue(prim, attr_name, attr_value_vec); \ } \ } +#endif + +#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ + ValuePtr ParseAttrInScalar_##type##_##valuetype(const onnx::TensorProto &attr_tensor) { \ + auto value = static_cast(attr_tensor.type##_data(0)); \ + return MakeValue(value); \ + } PARSE_ONNXATTR_IN_SCALAR_FORM(double, double) PARSE_ONNXATTR_IN_SCALAR_FORM(float, float) @@ -110,6 +212,7 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node, cons tensor::TensorPtr tensor_info = std::make_shared(kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape); MS_EXCEPTION_IF_NULL(tensor_info); + // tensor_info->MallocData(); auto tensor_abstract = tensor_info->ToAbstract(); MS_EXCEPTION_IF_NULL(tensor_abstract); node->set_abstract(tensor_abstract); @@ -167,45 +270,35 @@ bool MSANFModelParser::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const return true; } -bool MSANFModelParser::ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name, - const onnx::TensorProto &attr_tensor) { - MS_EXCEPTION_IF_NULL(prim); +ValuePtr MSANFModelParser::ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor) { const int attr_tensor_type = attr_tensor.data_type(); switch (attr_tensor_type) { case onnx::TensorProto_DataType_STRING: { - ParseAttrInScalar_string_string(prim, attr_name, attr_tensor); - break; + return ParseAttrInScalar_string_string(attr_tensor); } case onnx::TensorProto_DataType_INT32: { - ParseAttrInScalar_int32_int32(prim, attr_name, attr_tensor); - break; + return ParseAttrInScalar_int32_int32(attr_tensor); } case onnx::TensorProto_DataType_INT64: { - ParseAttrInScalar_int64_int64(prim, attr_name, attr_tensor); - break; + return ParseAttrInScalar_int64_int64(attr_tensor); } case onnx::TensorProto_DataType_UINT64: { - ParseAttrInScalar_uint64_uint64(prim, attr_name, attr_tensor); - break; + return ParseAttrInScalar_uint64_uint64(attr_tensor); } case onnx::TensorProto_DataType_FLOAT: { - ParseAttrInScalar_float_float(prim, attr_name, attr_tensor); - break; + return ParseAttrInScalar_float_float(attr_tensor); } case onnx::TensorProto_DataType_DOUBLE: { - ParseAttrInScalar_double_double(prim, attr_name, attr_tensor); - break; + return ParseAttrInScalar_double_double(attr_tensor); } case onnx::TensorProto_DataType_BOOL: { - ParseAttrInScalar_int32_bool(prim, attr_name, attr_tensor); - auto value = prim->GetAttr(attr_name); - break; + return ParseAttrInScalar_int32_bool(attr_tensor); } default: MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; - return false; + return {}; } - return true; + return {}; } bool MSANFModelParser::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, @@ -223,21 +316,48 @@ bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx return false; } const std::string &ref_attr_name = attr_proto.ref_attr_name(); - const onnx::TensorProto &attr_tensor = attr_proto.t(); - switch (kParseTypeSwitchMap[ref_attr_name]) { - case FORM_PARSE_TYPE: { - return ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor); - } - case FORM_PARSE_SCALAR: { - return ObtainCNodeAttrInScalarForm(prim, attr_name, attr_tensor); + string type; + std::size_t pos(0); + if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) { + type = ref_attr_name.substr(pos, string("scalar:").length() - 1); + } else if ((pos = ref_attr_name.find("type:")) != std::string::npos) { + type = ref_attr_name.substr(pos, string("type:").length() - 1); + } else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) { + type = ref_attr_name.substr(pos, string("tensor:").length() - 1); + } + std::unordered_map kv; + for (int i = 0; i < attr_proto.tensors_size(); i++) { + const onnx::TensorProto &attr_tensor = attr_proto.tensors(i); + switch (kParseTypeSwitchMap[type]) { + case FORM_PARSE_TYPE: { + ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor); + break; + } + case FORM_PARSE_SCALAR: { + auto res = ObtainCNodeAttrInScalarForm(attr_tensor); + kv.insert(std::pair(attr_tensor.name(), res)); + break; + } + case FORM_PARSE_TENSOR: { + ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor); + break; + } + default: + MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name"; + return false; } - case FORM_PARSE_TENSOR: { - return ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor); + } + + if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) { + if (kv.size() == 1) { + auto iter = kv.begin(); + prim->AddAttr(attr_name, iter->second); + } else { + auto res = ParserScalarAttrValue(ref_attr_name, kv); + prim->AddAttr(attr_name, res); } - default: - MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name"; - return false; } + return true; } bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node_name, const onnx::TensorProto &attr_tensor) { @@ -247,6 +367,7 @@ bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node shape.push_back(attr_tensor.dims(i)); } tensor::TensorPtr tensor_info = std::make_shared(kDefaultValueSwitchMap[attr_tensor_type], shape); + // tensor_info->MallocData(); const std::string &tensor_buf = attr_tensor.raw_data(); auto *tensor_data_buf = reinterpret_cast(tensor_info->data_c()); auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size()); @@ -324,22 +445,58 @@ bool MSANFModelParser::ObtainValueNodeInTypeForm(const std::string &value_node_n return true; } -bool MSANFModelParser::GetAttrValueForValueNode(const std::string &ref_attr_name, const std::string &value_node_name, - const onnx::TensorProto &attr_tensor) { - switch (kParseTypeSwitchMap[ref_attr_name]) { - case FORM_PARSE_SCALAR: { - return ObtainValueNodeInScalarForm(value_node_name, attr_tensor); - } - case FORM_PARSE_TENSOR: { - return ObtainValueNodeInTensorForm(value_node_name, attr_tensor); +bool MSANFModelParser::GetAttrValueForValueNode(const std::string &value_node_name, + const onnx::AttributeProto &attr_proto) { + if (!attr_proto.has_ref_attr_name()) { + MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name"; + return false; + } + const std::string &ref_attr_name = attr_proto.ref_attr_name(); + string type; + std::size_t pos(0); + if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) { + type = ref_attr_name.substr(pos, string("scalar:").length() - 1); + } else if ((pos = ref_attr_name.find("type:")) != std::string::npos) { + type = ref_attr_name.substr(pos, string("type:").length() - 1); + } else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) { + type = ref_attr_name.substr(pos, string("tensor:").length() - 1); + } + std::unordered_map kv; + for (int i = 0; i < attr_proto.tensors_size(); i++) { + const onnx::TensorProto &attr_tensor = attr_proto.tensors(i); + auto attr_name = attr_tensor.name(); + switch (kParseTypeSwitchMap[type]) { + case FORM_PARSE_TYPE: { + return ObtainValueNodeInTypeForm(value_node_name, attr_tensor); + } + case FORM_PARSE_SCALAR: { + auto res = ObtainCNodeAttrInScalarForm(attr_tensor); + kv.insert(std::pair(attr_tensor.name(), res)); + break; + } + case FORM_PARSE_TENSOR: { + return ObtainValueNodeInTensorForm(value_node_name, attr_tensor); + } + default: + MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name"; + return false; } - case FORM_PARSE_TYPE: { - return ObtainValueNodeInTypeForm(value_node_name, attr_tensor); + } + + ValueNodePtr new_value_node; + if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) { + if (kv.size() == 1) { + auto iter = kv.begin(); + new_value_node = NewValueNode(iter->second); + new_value_node->set_abstract(iter->second->ToAbstract()); + } else { + auto value_ptr = ParserScalarAttrValue(ref_attr_name, kv); + new_value_node = NewValueNode(value_ptr); + new_value_node->set_abstract(value_ptr->ToAbstract()); } - default: - MS_LOG(ERROR) << "parse ValueNode value don't support input of ref_attr_name"; - return false; + anfnode_build_map_[value_node_name] = new_value_node; } + return true; } bool MSANFModelParser::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) { @@ -349,24 +506,26 @@ bool MSANFModelParser::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_pr MS_LOG(ERROR) << "parse ValueNode don't have ref_attr_name"; return false; } - const std::string &ref_attr_name = attr_proto.ref_attr_name(); - const onnx::TensorProto &attr_tensor = attr_proto.t(); - - return GetAttrValueForValueNode(ref_attr_name, value_node_name, attr_tensor); + return GetAttrValueForValueNode(value_node_name, attr_proto); } -AbstractBasePtr MSANFModelParser::GetAbstractForCNode(const onnx::AttributeProto &attr_proto) { - ShapeVector shape_vec; - const onnx::TensorProto &attr_tensor = attr_proto.t(); - for (int i = 0; i < attr_tensor.dims_size(); ++i) { - shape_vec.push_back(attr_tensor.dims(i)); - } - tensor::TensorPtr tensor_info = - std::make_shared(kDefaultValueSwitchMap[attr_tensor.data_type()], shape_vec); - MS_EXCEPTION_IF_NULL(tensor_info); - auto abstract = tensor_info->ToAbstract(); - MS_EXCEPTION_IF_NULL(abstract); - return abstract; +std::unordered_map MSANFModelParser::GetAbstractForCNode( + const onnx::AttributeProto &attr_proto) { + std::unordered_map kv; + for (int i = 0; i < attr_proto.tensors_size(); ++i) { + ShapeVector shape_vec; + const onnx::TensorProto &attr_tensor = attr_proto.tensors(i); + for (int j = 0; j < attr_tensor.dims_size(); ++j) { + shape_vec.push_back(attr_tensor.dims(j)); + } + tensor::TensorPtr tensor_info = + std::make_shared(kDefaultValueSwitchMap[attr_tensor.data_type()], shape_vec); + MS_EXCEPTION_IF_NULL(tensor_info); + auto abstract = tensor_info->ToAbstract(); + MS_EXCEPTION_IF_NULL(abstract); + kv.insert(std::pair(attr_tensor.name(), abstract)); + } + return kv; } CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, @@ -383,21 +542,13 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc MS_EXCEPTION_IF_NULL(prim); prim->set_instance_name(node_type); - AbstractBasePtr abstract = nullptr; - AbstractBasePtr abstract_first = nullptr; - AbstractBasePtr abstract_second = nullptr; + std::unordered_map kv; + string shape_ref_attr_name; for (int i = 0; i < node_proto.attribute_size(); ++i) { const onnx::AttributeProto &attr_proto = node_proto.attribute(i); - if (attr_proto.name() == kCNodeShapeAttr) { - abstract = GetAbstractForCNode(attr_proto); - continue; - } - if (attr_proto.name() == kCNodeShape1Attr) { - abstract_first = GetAbstractForCNode(attr_proto); - continue; - } - if (attr_proto.name() == kCNodeShape2Attr) { - abstract_second = GetAbstractForCNode(attr_proto); + if (attr_proto.ref_attr_name().find("shape:") != string::npos) { + shape_ref_attr_name = attr_proto.ref_attr_name(); + kv = GetAbstractForCNode(attr_proto); continue; } if (!GetAttrValueForCNode(prim, attr_proto)) { @@ -419,24 +570,17 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc } CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(cnode_ptr); - if (node_type == "LayerNorm") { - AbstractBasePtrList elem; - elem.push_back(abstract); - elem.push_back(abstract_first); - elem.push_back(abstract_second); - cnode_ptr->set_abstract(std::make_shared(elem)); - } else if (node_type == "ArgMaxWithValue") { - AbstractBasePtrList elem; - elem.push_back(abstract); - elem.push_back(abstract_first); - cnode_ptr->set_abstract(std::make_shared(elem)); - } else if (nullptr == abstract) { + if (0 == kv.size()) { AbstractBasePtrList elem; for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) { elem.push_back(cnode_ptr->input(index)->abstract()); } cnode_ptr->set_abstract(std::make_shared(elem)); + } else if (1 == kv.size()) { + std::unordered_map::iterator iter = kv.begin(); + cnode_ptr->set_abstract(iter->second); } else { + auto abstract = ParserAttrShape(shape_ref_attr_name, kv); cnode_ptr->set_abstract(abstract); } cnode_ptr->set_fullname_with_scope(fullname_with_scope); @@ -471,19 +615,15 @@ bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGra } else { const onnx::ValueInfoProto &output_node = importProto.output(0); const onnx::TypeProto &output_typeproto = output_node.type(); - int output_type = output_typeproto.tensor_type().elem_type(); ShapeVector output_shape; for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size(); ++i) { output_shape.push_back(output_typeproto.tensor_type().shape().dim(i).dim_value()); } - tensor::TensorPtr tensor_return = - std::make_shared(kDefaultValueSwitchMap[output_type], output_shape); inputs.clear(); inputs.push_back(NewValueNode(prim::kPrimReturn)); inputs.push_back(cnode_ptr); auto return_node = outputFuncGraph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(return_node); - return_node->set_abstract(tensor_return->ToAbstract()); outputFuncGraph->set_return(return_node); MS_LOG(INFO) << "Construct funcgraph finined, all success!"; } diff --git a/mindspore/ccsrc/utils/load_onnx/anf_model_parser.h b/mindspore/ccsrc/utils/load_onnx/anf_model_parser.h index 25152687f9e09191f3299d4c6a04e063b1aa3d14..5dc0b17b3567128baab619532ac6be430a846ebb 100644 --- a/mindspore/ccsrc/utils/load_onnx/anf_model_parser.h +++ b/mindspore/ccsrc/utils/load_onnx/anf_model_parser.h @@ -52,18 +52,17 @@ class MSANFModelParser { bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto); bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name, const onnx::TensorProto &attr_tensor); - bool ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name, - const onnx::TensorProto &attr_tensor); + ValuePtr ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor); bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, const onnx::TensorProto &attr_tensor); bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto); bool ObtainValueNodeInTensorForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); bool ObtainValueNodeInScalarForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); - bool GetAttrValueForValueNode(const string &ref_attr_name, const std::string &value_node_name, - const onnx::TensorProto &attr_tensor); + bool GetAttrValueForValueNode(const std::string &value_node_name, const onnx::AttributeProto &attr_tensor); bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); - AbstractBasePtr GetAbstractForCNode(const onnx::AttributeProto &attr_proto); + std::unordered_map GetAbstractForCNode( + const onnx::AttributeProto &attr_proto); std::string producer_name_; int model_version_;