diff --git a/mindspore/ccsrc/debug/anf_ir_utils.cc b/mindspore/ccsrc/debug/anf_ir_utils.cc index 36fdd5a300f851eec52678db3f90c35dbc7a161b..781f5b1623660a6196974015236deac361b83f20 100644 --- a/mindspore/ccsrc/debug/anf_ir_utils.cc +++ b/mindspore/ccsrc/debug/anf_ir_utils.cc @@ -1566,7 +1566,7 @@ class IrParser { return lexer_.GetNextToken(); } else if (type == "Tuple") { return ParseTypeVector(func_graph, lexer_.GetNextToken(), type, ptr); - } else if (type == "Array") { + } else if (type == "Tensor") { return ParseTypeArray(func_graph, lexer_.GetNextToken(), ptr); } else if (type == "List") { return ParseTypeVector(func_graph, lexer_.GetNextToken(), type, ptr); diff --git a/mindspore/ccsrc/debug/anf_ir_utils.h b/mindspore/ccsrc/debug/anf_ir_utils.h index 9601e9d87aa3f120fea65e0d08403f5c926509e1..4503692eb96e54abf3f8752804bc3efaec7ec97f 100644 --- a/mindspore/ccsrc/debug/anf_ir_utils.h +++ b/mindspore/ccsrc/debug/anf_ir_utils.h @@ -118,6 +118,8 @@ std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph); void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix); std::string GetOnnxProtoString(const FuncGraphPtr &func_graph); + +std::string GetBinaryProtoString(const FuncGraphPtr &func_graph); } // namespace mindspore #endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_ diff --git a/mindspore/ccsrc/onnx/ir_exporter.cc b/mindspore/ccsrc/onnx/ir_exporter.cc new file mode 100644 index 0000000000000000000000000000000000000000..687d7c23e2c578c81906f7edfc7ed8672489529c --- /dev/null +++ b/mindspore/ccsrc/onnx/ir_exporter.cc @@ -0,0 +1,631 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ir/param_value_py.h" +#include "debug/anf_ir_utils.h" +#include "operator/ops.h" +#include "proto/onnx.pb.h" + +namespace mindspore { + +using FloatPtr = std::shared_ptr; +using IntPtr = std::shared_ptr; + +// anf type to onnx type map +static std::unordered_map g_data_type_map = { + {kNumberTypeBool, onnx::TensorProto_DataType_BOOL}, {kNumberTypeInt8, onnx::TensorProto_DataType_INT8}, + {kNumberTypeInt16, onnx::TensorProto_DataType_INT16}, {kNumberTypeInt32, onnx::TensorProto_DataType_INT32}, + {kNumberTypeInt64, onnx::TensorProto_DataType_INT64}, {kNumberTypeUInt8, onnx::TensorProto_DataType_UINT8}, + {kNumberTypeUInt16, onnx::TensorProto_DataType_UINT16}, {kNumberTypeUInt32, onnx::TensorProto_DataType_UINT32}, + {kNumberTypeUInt64, onnx::TensorProto_DataType_UINT64}, {kNumberTypeFloat16, onnx::TensorProto_DataType_FLOAT16}, + {kNumberTypeFloat32, onnx::TensorProto_DataType_FLOAT}, {kNumberTypeFloat64, onnx::TensorProto_DataType_DOUBLE}, + {kObjectTypeString, onnx::TensorProto_DataType_STRING}, +}; + +static std::unordered_map g_data_bits_int_map = { + {8, onnx::TensorProto_DataType_INT8}, + {16, onnx::TensorProto_DataType_INT16}, + {32, onnx::TensorProto_DataType_INT32}, + {64, onnx::TensorProto_DataType_INT64}, +}; + +static std::unordered_map g_data_bits_float_map = { + {16, onnx::TensorProto_DataType_FLOAT16}, + {32, onnx::TensorProto_DataType_FLOAT}, +}; + +// Can build different builder according to format +class IrExportBuilder; +using IrExportBuilderPtr = std::shared_ptr; + +class IrExporter { + public: + explicit IrExporter(IrExportBuilderPtr builder) : builder_(builder) {} + virtual ~IrExporter() = default; + std::string GetDumpString(const FuncGraphPtr &func_graph); + + private: + IrExportBuilderPtr builder_; +}; + +class IrExportBuilder { + public: + IrExportBuilder() = default; + ~IrExportBuilder() { google::protobuf::ShutdownProtobufLibrary(); } + std::string GetProtoString(const FuncGraphPtr &func_graph); + void BuildModelInfo(); + void BuildModel(const FuncGraphPtr &func_graph); + + private: + void BuildFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto); + void BuildParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto); + void BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto); + void BuildOutput(const CNodePtr &node, onnx::GraphProto *const graph_proto); + void BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto); + std::string BuildInputNode(const AnfNodePtr &node, onnx::GraphProto *const graph_proto); + + void SetValueInfoProto(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto); + void SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape, onnx::ValueInfoProto *const value_proto); + void SetParamToTensorProto(const ParameterPtr ¶m, onnx::TensorProto *const tensor_proto); + void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, onnx::TensorProto *const tensor_proto); + void SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto); + void SetShapeToNodeProto(const CNodePtr &node, const std::vector &inputs, + onnx::NodeProto *const node_proto); + void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::NodeProto *const node_proto); + void SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); + void SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); + void SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); + void SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); + void SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto); + void SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto); + + onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id); + onnx::TensorProto_DataType GetOnnxDataBitsIntType(int bits); + onnx::TensorProto_DataType GetOnnxDataBitsFloatType(int bits); + std::string GetNodeName(const AnfNodePtr &node); + std::string GetUniqueNodeName(const AnfNodePtr &node); + std::string GetOpTypeName(const AnfNodePtr &node); + size_t AllocateIndex() { return ++node_index_; } + void ResetIndex() { node_index_ = 0; } + + private: + onnx::ModelProto model_; + onnx::NodeProto *last_node_; + std::list todo_; + std::map node_index_map_; + size_t node_index_ = 0; +}; + +using IrExporterPtr = std::shared_ptr; + +std::string IrExporter::GetDumpString(const FuncGraphPtr &func_graph) { + if ((builder_ == nullptr) || (func_graph == nullptr)) { + MS_LOG(EXCEPTION) << "Input params is null."; + } + + // Export model info + builder_->BuildModelInfo(); + + // Export model and return string + builder_->BuildModel(func_graph); + + return builder_->GetProtoString(func_graph); +} + +std::string IrExportBuilder::GetProtoString(const FuncGraphPtr &func_graph) { + MS_LOG(DEBUG) << "BuildModel complete!"; + return model_.SerializeAsString(); +} + +void IrExportBuilder::BuildModelInfo() { + model_.set_ir_version(onnx::IR_VERSION_2019_1_22); + model_.set_producer_name("MindSpore"); + model_.set_model_version(1); +} + +void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) { + onnx::GraphProto *graph_proto = model_.mutable_graph(); + graph_proto->set_name(func_graph->ToString()); + ResetIndex(); + todo_.clear(); + todo_.push_back(func_graph); + while (!todo_.empty()) { + FuncGraphPtr fg = todo_.back(); + todo_.pop_back(); + BuildFuncGraph(fg, graph_proto); + } +} + +void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { + // Export parameters + // 1. parameters should be mapped to ValueInfoProto + // 2. parameters with default value should be mapped to Initializer + BuildParameters(func_graph, graph_proto); + + // Export operator nodes(include output) + BuildNodes(func_graph, graph_proto); +} + +void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { + for (auto &item : func_graph->parameters()) { + auto param = item->cast(); + if (param == nullptr) { + MS_LOG(EXCEPTION) << "Parameter: '" << item->ToString() << "' could not cast to parameter."; + } + onnx::ValueInfoProto *input_proto = graph_proto->add_input(); + std::string param_name = GetUniqueNodeName(param); + input_proto->set_name(param_name); + SetValueInfoProto(param, input_proto); + if (!param->has_default()) { + MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default"; + continue; + } + + // Using ONNX initializer to set parameter's default value + onnx::TensorProto *initializer_proto = graph_proto->add_initializer(); + initializer_proto->set_name(param_name); + SetParamToTensorProto(param, initializer_proto); + auto param_value = std::dynamic_pointer_cast(param->default_param()); + py::object obj = param_value->value(); + py::object data = obj.attr("data"); + if (py::isinstance(data)) { + auto method = data.attr("asnumpy"); + py::array npy_data = method(); + initializer_proto->set_raw_data(npy_data.request(true).ptr, static_cast(npy_data.nbytes())); + } + } +} + +onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataType(TypeId type_id) { + auto iter = g_data_type_map.find(type_id); + if (iter == g_data_type_map.end()) { + MS_LOG(EXCEPTION) << "Convert type error, unsupported type! " << type_id; + } + return iter->second; +} + +onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataBitsIntType(int bits) { + auto iter = g_data_bits_int_map.find(bits); + if (iter == g_data_bits_int_map.end()) { + MS_LOG(EXCEPTION) << "Convert bits int error, unsupported bits! " << bits; + } + return iter->second; +} + +onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataBitsFloatType(int bits) { + auto iter = g_data_bits_float_map.find(bits); + if (iter == g_data_bits_float_map.end()) { + MS_LOG(EXCEPTION) << "Convert bits float error, unsupported bits! " << bits; + } + return iter->second; +} + +void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto) { + if (node == nullptr || value_proto == nullptr) { + MS_LOG(EXCEPTION) << "AnfNode or ValueInfo is null!"; + } + MS_LOG(DEBUG) << "SetValueInfoProto: " << node->DebugString(); + SetValueInfoProto(node->Type(), node->Shape(), value_proto); +} + +void IrExportBuilder::SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape, + onnx::ValueInfoProto *const value_proto) { + onnx::TypeProto *type_proto = value_proto->mutable_type(); + if (type->isa() && shape->isa()) { + auto tensor = type->cast(); + auto elem_type = tensor->element(); + const auto &dims = shape->cast()->shape(); + type_proto->mutable_tensor_type()->set_elem_type(GetOnnxDataType(elem_type->type_id())); + for (const auto &dim : dims) { + MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim; + type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); + } + } else if (type->isa()) { + auto tup_shape = shape->cast(); + type_proto->set_denotation(std::to_string(tup_shape->shape().size())); + } else { + MS_LOG(EXCEPTION) << "Value type: " << type->type_name() << " is not supported!"; + } +} + +void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) { + if (value == nullptr || attr_proto == nullptr) { + MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; + } + attr_proto->set_ref_attr_name("tensor"); + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); + auto data = value->cast(); + tensor_proto->set_raw_data(data->data().request(true).ptr, static_cast(data->data().nbytes())); + auto dtype = data->data_type(); + auto shape = data->shape_c(); + tensor_proto->set_data_type(GetOnnxDataType(dtype)); + for (const auto &dim : shape) { + tensor_proto->add_dims(dim); + } +} + +void IrExportBuilder::SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, + onnx::TensorProto *const tensor_proto) { + if (!type->isa() || !shape->isa()) { + MS_LOG(EXCEPTION) << "Type or shape is not supported! " << type->ToString(); + } + auto tensor = type->cast(); + const auto &dims = shape->cast()->shape(); + tensor_proto->set_data_type(GetOnnxDataType(tensor->element()->type_id())); + for (const auto &dim : dims) { + tensor_proto->add_dims(dim); + } +} + +void IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, onnx::TensorProto *const tensor_proto) { + if (param == nullptr || tensor_proto == nullptr) { + MS_LOG(EXCEPTION) << "Parameter or TensorProto is null!"; + } + MS_LOG(DEBUG) << "SetParamToTensorProto: " << param->DebugString(); + SetTensorProto(param->Type(), param->Shape(), tensor_proto); +} + +void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { + std::vector nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); + for (const AnfNodePtr &node : nodes) { + if (!node->isa()) { + MS_LOG(DEBUG) << "Node: '" << node->ToString() << "' is not cnode"; + continue; + } + auto cnode = node->cast(); + if (cnode == func_graph->get_return()) { + BuildOutput(cnode, graph_proto); + } else { + BuildCNode(cnode, graph_proto); + } + } +} + +void IrExportBuilder::BuildOutput(const CNodePtr &node, onnx::GraphProto *const graph_proto) { + if (node->size() != 2) { + MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2."; + } + AnfNodePtr arg = node->input(1); + // Using make_tuple to set multi-output + if (IsPrimitiveCNode(arg, prim::kPrimMakeTuple)) { + auto tuple_node = arg->cast(); + for (size_t i = 1; i < tuple_node->size(); i++) { + auto input_node = arg->cast()->input(i); + onnx::ValueInfoProto *output_proto = graph_proto->add_output(); + auto output_name = GetUniqueNodeName(tuple_node->input(i)); + output_proto->set_name(output_name); + last_node_->add_output(output_name); + SetValueInfoProto(tuple_node->input(i), output_proto); + } + } else { + onnx::ValueInfoProto *output_proto = graph_proto->add_output(); + std::string output_name = GetUniqueNodeName(node); + output_proto->set_name(output_name); + last_node_->add_output(output_name); + SetValueInfoProto(arg, output_proto); + } +} + +std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) { + // May be ValueNode/CNode/Parameter + std::string type_name = ""; + if (IsValueNode(node)) { + PrimitivePtr prim = GetValueNode(node); + type_name = prim->ToString(); + } else if (IsValueNode(node)) { + FuncGraphPtr fg = GetValueNode(node); + todo_.push_back(fg); + type_name = fg->ToString(); + } else if (node->isa() || node->isa()) { + type_name = node->ToString(); + } else { + MS_LOG(EXCEPTION) << "Need to support op type: " << node->type_name(); + } + MS_LOG(DEBUG) << "ExportType: " << type_name; + return type_name; +} + +void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, + onnx::NodeProto *const node_proto) { + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_ref_attr_name("shape"); + attr_proto->set_name("shape"); + onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); + SetTensorProto(type, shape, tensor_proto); +} + +void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, const std::vector &inputs, + onnx::NodeProto *const node_proto) { + // Get shape of cnode + // 1. prim kPrimTupleGetItem need to get shape of input node according to the index + // 2. some cnode doesn't has shape, such as LayerNorm + // 3. other cnodes have shape + if (node->IsApply(prim::kPrimTupleGetItem)) { + // Get index of tuple get_item + int index_pos = inputs.size() - 1; + if (!inputs[index_pos]->isa()) { + MS_LOG(EXCEPTION) << "Index is not ValueNode: " << index_pos; + } + auto value = inputs[index_pos]->cast()->value(); + if (!value->isa()) { + MS_LOG(EXCEPTION) << "Index type is not supported: " << value->type_name(); + } + size_t index = GetValue(value); + + // Get type and shape of input node + auto tup_type = inputs[0]->Type(); + if (!tup_type->isa()) { + MS_LOG(EXCEPTION) << "Input data of kPrimTupleGetItem cnode must be tuple: " << tup_type->type_name(); + } + auto type = tup_type->cast()->elements()[index]; + auto tup_shape = inputs[0]->Shape()->cast(); + if (index >= tup_shape->shape().size()) { + MS_LOG(EXCEPTION) << "Index exceed upper limit: " << tup_shape->shape().size(); + } + auto shape = tup_shape->shape()[index]; + SetShapeToNodeProto(type, shape, node_proto); + } else { + auto type = node->Type(); + auto shape = node->Shape(); + if (!type->isa() || !shape->isa()) { + MS_LOG(DEBUG) << "Cnode has no shape: " << node->ToString(); + return; + } + SetShapeToNodeProto(type, shape, node_proto); + } +} + +void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto) { + auto inputs_size = node->size(); + if (inputs_size < 1) { + MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; + } + + // Need to build input node before dealing with cnode + std::vector op_inputs; + std::vector input_names; + for (size_t i = 1; i < inputs_size; i++) { + auto input = node->input(i); + op_inputs.push_back(input); + input_names.push_back(BuildInputNode(input, graph_proto)); + } + + // Build cnode + onnx::NodeProto *node_proto = graph_proto->add_node(); + std::string output_name = GetUniqueNodeName(node); + node_proto->add_output(output_name); + node_proto->set_name(output_name); + AnfNodePtr op = node->input(0); + std::string type_name = GetOpTypeName(op); + node_proto->set_op_type(type_name); + last_node_ = node_proto; + SetShapeToNodeProto(node, op_inputs, node_proto); + (void)std::for_each(input_names.begin(), input_names.end(), + [&node_proto](const string &name) { node_proto->add_input(name); }); + + // Add primitive attrs + if (IsValueNode(op)) { + auto prim = GetValueNode(op); + for (auto attr : prim->attrs()) { + MS_LOG(DEBUG) << "attr: " << attr.first << " " << attr.second->DumpText() << " " << attr.second->type_name(); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_name(attr.first); + SetValueToAttributeProto(attr.second, attr_proto); + } + } else { + MS_LOG(EXCEPTION) << "Need to support op type: " << op->type_name(); + } +} + +std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, onnx::GraphProto *const graph_proto) { + std::string node_name = GetUniqueNodeName(node); + if (node->isa()) { + // When node input is a ValueNode, need to create a Constant Node + onnx::NodeProto *node_proto = graph_proto->add_node(); + node_proto->add_output(node_name); + SetAttributeProto(node, node_proto); + } + return node_name; +} + +std::string IrExportBuilder::GetUniqueNodeName(const AnfNodePtr &node) { + // Naming anfnode + // 1. parameter is unique in one func_graph + // 2. cnode and valuenode may be reduplicative, so add index to identify. + std::string node_name = ""; + if (node->isa()) { + node_name = GetNodeName(node); + } else if (node->isa() || node->isa()) { + auto iter = node_index_map_.find(node); + if (iter != node_index_map_.end()) { + node_name = GetNodeName(node) + ":" + std::to_string(iter->second); + } else { + auto node_idx = AllocateIndex(); + node_index_map_[node] = node_idx; + node_name = GetNodeName(node) + ":" + std::to_string(node_idx); + } + } else { + MS_LOG(EXCEPTION) << "Can not support type of node:" << node->ToString(); + } + MS_LOG(DEBUG) << "Node name: " << node_name; + return node_name; +} + +std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) { + std::string node_name = ""; + if ((node != nullptr) && (node->func_graph() != nullptr)) { + node_name = node->func_graph()->ToString() + ":"; + } + node_name += node->ToString(); + MS_LOG(DEBUG) << "GetNodeName: " << node_name; + return node_name; +} + +void IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto) { + if (node == nullptr || node_proto == nullptr) { + MS_LOG(EXCEPTION) << "AnfNode or NodeProto is null!"; + } + auto value = node->cast()->value(); + node_proto->set_op_type("Constant"); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_name("value"); + MS_LOG(DEBUG) << "Set Constant attribute: " << value->ToString(); + SetValueToAttributeProto(value, attr_proto); +} + +void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) { + if (value == nullptr || attr_proto == nullptr) { + MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; + } + attr_proto->set_ref_attr_name("type"); + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); + if (value->isa()) { + auto int_value = value->cast(); + tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits())); + } else if (value->isa()) { + auto float_value = value->cast(); + tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits())); + } else if (value->isa()) { + tensor_proto->set_name("tensor"); + auto elem_type = value->cast()->element(); + if (elem_type->isa()) { + auto int_value = elem_type->cast(); + tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits())); + } else if (elem_type->isa()) { + auto float_value = elem_type->cast(); + tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits())); + } else { + MS_LOG(EXCEPTION) << "Unsupported type " << elem_type->type_name(); + } + } else { + MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name(); + } +} + +void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) { + if (value == nullptr || attr_proto == nullptr) { + MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; + } + if (value->isa() || value->isa()) { + SetScalarToAttributeProto(value, attr_proto); + } else if (value->isa() || value->isa()) { + SetTypeToAttributeProto(value, attr_proto); + } else if (value->isa()) { + SetSequenceToAttributeProto(value->cast(), attr_proto); + } else if (value->isa()) { + SetTensorToAttributeProto(value, attr_proto); + } else { + MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name(); + } +} + +void IrExportBuilder::SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) { + if (value == nullptr || attr_proto == nullptr) { + MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; + } + attr_proto->set_ref_attr_name("scalar"); + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); + SetScalarToProto(value, tensor_proto); +} + +void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto) { + if (value == nullptr || tensor_proto == nullptr) { + MS_LOG(EXCEPTION) << "ValuePtr or TensorProto is null!"; + } + if (value->isa()) { + tensor_proto->set_data_type(onnx::TensorProto_DataType_STRING); + tensor_proto->add_string_data(GetValue(value)); + } else if (value->isa()) { + tensor_proto->set_data_type(onnx::TensorProto_DataType_BOOL); + tensor_proto->add_int32_data(GetValue(value)); + } else if (value->isa()) { + tensor_proto->set_data_type(onnx::TensorProto_DataType_INT8); + tensor_proto->add_int32_data(value->cast()->value()); + } else if (value->isa()) { + tensor_proto->set_data_type(onnx::TensorProto_DataType_INT16); + tensor_proto->add_int32_data(value->cast()->value()); + } else if (value->isa()) { + tensor_proto->set_data_type(onnx::TensorProto_DataType_INT32); + tensor_proto->add_int32_data(value->cast()->value()); + } else if (value->isa()) { + tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64); + tensor_proto->add_int64_data(value->cast()->value()); + } else if (value->isa()) { + tensor_proto->set_data_type(onnx::TensorProto_DataType_FLOAT); + tensor_proto->add_float_data(GetValue(value)); + } else { + MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name(); + } +} + +void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, + onnx::AttributeProto *const attr_proto) { + if (value == nullptr || attr_proto == nullptr) { + MS_LOG(EXCEPTION) << "ValueSequeuePtr or AttributeProto is null!"; + } + attr_proto->set_ref_attr_name("scalar"); + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); + if (value->isa()) { + const ValueTuplePtr &tuple_value = value->cast(); + if (tuple_value->value().size() == 0) { + MS_LOG(DEBUG) << "SetSequenceToAttributeProto tuple size is 0"; + return; + } + auto type_id = tuple_value->value()[0]->type()->type_id(); + tensor_proto->set_data_type(GetOnnxDataType(type_id)); + for (const auto &item : tuple_value->value()) { + SetScalarToProto(item, tensor_proto); + } + } else if (value->isa()) { + const ValueListPtr &list_value = value->cast(); + if (list_value->value().size() == 0) { + MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0"; + return; + } + auto type_id = list_value->value()[0]->type()->type_id(); + tensor_proto->set_data_type(GetOnnxDataType(type_id)); + for (const auto &item : list_value->value()) { + SetScalarToProto(item, tensor_proto); + } + } +} + +std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) { + auto builder = std::make_shared(); + if (builder == nullptr) { + MS_LOG(ERROR) << "Create ir exporter failed!"; + return ""; + } + auto exporter = std::make_shared(builder); + if (exporter == nullptr) { + return ""; + } + return exporter->GetDumpString(func_graph); +} +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/pipeline.cc b/mindspore/ccsrc/pipeline/pipeline.cc index e154ad7f7ce0b82b29ecd36810f8d3104ed67547..bc204d2ec8bd03da8e987f516ea136905a6c9dd1 100644 --- a/mindspore/ccsrc/pipeline/pipeline.cc +++ b/mindspore/ccsrc/pipeline/pipeline.cc @@ -59,6 +59,7 @@ using mindspore::abstract::AbstractTuplePtr; const char IR_TYPE_ANF[] = "anf_ir"; const char IR_TYPE_ONNX[] = "onnx_ir"; +const char IR_TYPE_BINARY[] = "binary_ir"; ExecutorPyPtr ExecutorPy::executor_ = nullptr; std::mutex ExecutorPy::instance_lock_; @@ -212,6 +213,14 @@ py::bytes ExecutorPy::GetFuncGraphProto(const std::string &phase, const std::str return proto_str; } + if (ir_type == IR_TYPE_BINARY) { + std::string proto_str = GetBinaryProtoString(fg_ptr); + if (proto_str.empty()) { + MS_LOG(EXCEPTION) << "Graph proto is empty."; + } + return proto_str; + } + MS_LOG(EXCEPTION) << "Unknown ir type: " << ir_type; } @@ -506,7 +515,6 @@ void RunPipelineAction(const ActionItem &action, pipeline::ResourcePtr resource, // when in loading anf ir mode, action `parse` do nothing if (action.first == "parse") { - parse::PythonAdapter::SetPythonEnvFlag(true); return; } @@ -566,6 +574,7 @@ void Pipeline::Run() { draw::Draw(base_name + ".dot", graph); // generate IR file in human readable format DumpIR(base_name + ".ir", graph); + // generate IR file in a heavily commented format, which can also be reloaded if (action.first != "parse") { ExportIR(base_name + ".dat", std::to_string(i), graph); diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 502f00572f24ddeac16c280915e3af9727ff2496..d09494d1897abfb899f9e21a16201c7b6cf0b711 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -398,17 +398,18 @@ def export(net, *inputs, file_name, file_format='GEIR'): net (Cell): MindSpore network. inputs (Tensor): Inputs of the `net`. file_name (str): File name of model to export. - file_format (str): MindSpore currently supports 'GEIR', 'ONNX' and 'LITE' format for exported model. + file_format (str): MindSpore currently supports 'GEIR', 'ONNX' 'LITE' and 'BINARY' format for exported model. - GEIR: Graph Engine Intermidiate Representation. An intermidiate representation format of Ascend model. - ONNX: Open Neural Network eXchange. An open format built to represent machine learning models. - LITE: Huawei model format for mobile. A lite model only for the MindSpore Lite + - BINARY: Binary format for model. An intermidiate representation format for models. """ logger.info("exporting model file:%s format:%s.", file_name, file_format) check_input_data(*inputs, data_class=Tensor) - supported_formats = ['GEIR', 'ONNX', 'LITE'] + supported_formats = ['GEIR', 'ONNX', 'LITE', 'BINARY'] if file_format not in supported_formats: raise ValueError(f'Illegal file format {file_format}, it must be one of {supported_formats}') # switch network mode to infer when it is training @@ -428,6 +429,13 @@ def export(net, *inputs, file_name, file_format='GEIR'): with open(file_name, 'wb') as f: os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) f.write(onnx_stream) + elif file_format == 'BINARY': # file_format is 'BINARY' + phase_name = 'export_binary' + graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False) + onnx_stream = _executor._get_func_graph_proto(graph_id, 'binary_ir') + with open(file_name, 'wb') as f: + os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) + f.write(onnx_stream) elif file_format == 'LITE': # file_format is 'LITE' context.set_context(save_ms_model=True, save_ms_model_path=file_name) net(*inputs) diff --git a/tests/ut/cpp/stub/anf_ir/dump_proto_stub.cc b/tests/ut/cpp/stub/anf_ir/dump_proto_stub.cc index 871fffc1c7ed996977a086f3f0a947876307428d..45b2f422eafd6bca667556b3b71ed1cbcee40038 100644 --- a/tests/ut/cpp/stub/anf_ir/dump_proto_stub.cc +++ b/tests/ut/cpp/stub/anf_ir/dump_proto_stub.cc @@ -17,8 +17,9 @@ namespace mindspore { -std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph) { return ""; } +std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph) { return ""; } -std::string GetOnnxProtoString(const FuncGraphPtr& func_graph) { return ""; } +std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) { return ""; } +std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) { return ""; } } // namespace mindspore