Commit ef35e2d9 authored by mindspore-ci-bot, committed by Gitee

!2002 Add dump ir function in binary format

Merge pull request !2002 from leopz/test_dump
@@ -1566,7 +1566,7 @@ class IrParser {
return lexer_.GetNextToken();
} else if (type == "Tuple") {
return ParseTypeVector(func_graph, lexer_.GetNextToken(), type, ptr);
} else if (type == "Array") {
} else if (type == "Tensor") {
return ParseTypeArray(func_graph, lexer_.GetNextToken(), ptr);
} else if (type == "List") {
return ParseTypeVector(func_graph, lexer_.GetNextToken(), type, ptr);
......
@@ -118,6 +118,8 @@ std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph);
void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix);
std::string GetOnnxProtoString(const FuncGraphPtr &func_graph);
std::string GetBinaryProtoString(const FuncGraphPtr &func_graph);
} // namespace mindspore
#endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_
This diff is collapsed.
@@ -59,6 +59,7 @@ using mindspore::abstract::AbstractTuplePtr;
const char IR_TYPE_ANF[] = "anf_ir";
const char IR_TYPE_ONNX[] = "onnx_ir";
const char IR_TYPE_BINARY[] = "binary_ir";
ExecutorPyPtr ExecutorPy::executor_ = nullptr;
std::mutex ExecutorPy::instance_lock_;
@@ -212,6 +213,14 @@ py::bytes ExecutorPy::GetFuncGraphProto(const std::string &phase, const std::str
return proto_str;
}
if (ir_type == IR_TYPE_BINARY) {
std::string proto_str = GetBinaryProtoString(fg_ptr);
if (proto_str.empty()) {
MS_LOG(EXCEPTION) << "Graph proto is empty.";
}
return proto_str;
}
MS_LOG(EXCEPTION) << "Unknown ir type: " << ir_type;
}
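The `IR_TYPE_BINARY` branch above is what the Python frontend reaches when it asks for `'binary_ir'`. A minimal sketch of driving it directly, assuming `_executor` is the graph executor that `serialization.py` imports from `mindspore.common.api` (the `TinyNet` network and its shapes are invented for illustration; the phase name and the two executor calls mirror the `export` change further down):

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
# Assumption: _executor is the same executor object serialization.py uses.
from mindspore.common.api import _executor


class TinyNet(nn.Cell):
    """Hypothetical one-layer network, only here to give the compiler a graph."""
    def __init__(self):
        super(TinyNet, self).__init__()
        self.dense = nn.Dense(4, 2)

    def construct(self, x):
        return self.dense(x)


net = TinyNet()
x = Tensor(np.ones([1, 4]).astype(np.float32))

# Compile without device conversion, then request the serialized graph;
# 'binary_ir' selects the new IR_TYPE_BINARY branch shown above.
graph_id, _ = _executor.compile(net, x, phase='export_binary', do_convert=False)
proto_bytes = _executor._get_func_graph_proto(graph_id, 'binary_ir')
print(len(proto_bytes))
```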
@@ -506,7 +515,6 @@ void RunPipelineAction(const ActionItem &action, pipeline::ResourcePtr resource,
// when in loading anf ir mode, action `parse` does nothing
if (action.first == "parse") {
parse::PythonAdapter::SetPythonEnvFlag(true);
return;
}
@@ -566,6 +574,7 @@ void Pipeline::Run() {
draw::Draw(base_name + ".dot", graph);
// generate IR file in human readable format
DumpIR(base_name + ".ir", graph);
// generate IR file in a heavily commented format, which can also be reloaded
if (action.first != "parse") {
ExportIR(base_name + ".dat", std::to_string(i), graph);
......
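The `.dot`, `.ir` and reloadable `.dat` files above are written once per pipeline action, and only when graph saving is enabled. A hedged sketch of turning that on from Python, assuming the `save_graphs` and `save_graphs_path` context options (not part of this diff) control these dumps:

```python
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor

# Assumption: save_graphs / save_graphs_path enable the per-action dumps
# (*.dot, *.ir and the *.dat produced by ExportIR above).
context.set_context(mode=context.GRAPH_MODE, save_graphs=True, save_graphs_path='./ir_dumps')

# Any small network is enough to trigger graph compilation and the dumps.
net = nn.Dense(4, 2)
out = net(Tensor(np.ones([1, 4]).astype(np.float32)))
print(out.shape)
```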
@@ -398,17 +398,18 @@ def export(net, *inputs, file_name, file_format='GEIR'):
net (Cell): MindSpore network.
inputs (Tensor): Inputs of the `net`.
file_name (str): File name of model to export.
file_format (str): MindSpore currently supports 'GEIR', 'ONNX' and 'LITE' format for exported model.
file_format (str): MindSpore currently supports 'GEIR', 'ONNX', 'LITE' and 'BINARY' formats for exported models.
- GEIR: Graph Engine Intermediate Representation. An intermediate representation format of
Ascend model.
- ONNX: Open Neural Network eXchange. An open format built to represent machine learning models.
- LITE: Huawei model format for mobile. A lite model only for the MindSpore Lite
- BINARY: Binary format for model. An intermediate representation format for models.
"""
logger.info("exporting model file:%s format:%s.", file_name, file_format)
check_input_data(*inputs, data_class=Tensor)
supported_formats = ['GEIR', 'ONNX', 'LITE']
supported_formats = ['GEIR', 'ONNX', 'LITE', 'BINARY']
if file_format not in supported_formats:
raise ValueError(f'Illegal file format {file_format}, it must be one of {supported_formats}')
# switch network mode to infer when it is training
@@ -428,6 +429,13 @@ def export(net, *inputs, file_name, file_format='GEIR'):
with open(file_name, 'wb') as f:
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
f.write(onnx_stream)
elif file_format == 'BINARY': # file_format is 'BINARY'
phase_name = 'export_binary'
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
onnx_stream = _executor._get_func_graph_proto(graph_id, 'binary_ir')
with open(file_name, 'wb') as f:
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
f.write(onnx_stream)
elif file_format == 'LITE': # file_format is 'LITE'
context.set_context(save_ms_model=True, save_ms_model_path=file_name)
net(*inputs)
......
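End to end, the new branch is reached through `export` with `file_format='BINARY'`. A hedged usage sketch, assuming `export` is importable from `mindspore.train.serialization` (the diff does not show file paths); `SimpleNet`, the input shape and the output file name are placeholders:

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.train.serialization import export  # assumption: export lives here


class SimpleNet(nn.Cell):
    """Placeholder network used only to demonstrate the new 'BINARY' format."""
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(32 * 32, 10)

    def construct(self, x):
        return self.fc(self.flatten(x))


net = SimpleNet()
dummy_input = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))

# 'BINARY' serializes the compiled func graph via GetBinaryProtoString and
# writes the raw proto bytes to disk, just like the ONNX branch does.
export(net, dummy_input, file_name='simple_net.pb', file_format='BINARY')
```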
@@ -17,8 +17,9 @@
namespace mindspore {
std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph) { return ""; }
std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph) { return ""; }
std::string GetOnnxProtoString(const FuncGraphPtr& func_graph) { return ""; }
std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) { return ""; }
std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) { return ""; }
} // namespace mindspore