提交 4a001ece 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!607 optimize flow of export onnx model

Merge pull request !607 from fary86/optimize_flow_of_exporting_onnx_model
......@@ -294,6 +294,30 @@ void ExecutorPy::SaveCompiledGraph(const std::string &phase_s) {
MS_LOG(INFO) << "End save compiled func graph!";
}
void ExecutorPy::SaveCompiledGraphToPb(const std::string &phase_s) {
#ifdef ENABLE_DUMP_IR
// save the graph to file in protobuf format
FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
std::string name_prefix = phase_s.substr(0, phase_s.find("."));
std::string pb_filename = std::string("ms_output_") + name_prefix + ".pb";
std::string filename = GetFilePathName(pb_filename);
MS_LOG(INFO) << "Begin saving graph to file <<'" << filename << "' in protobuf formart.";
ChangeFileMode(filename, S_IRWXU);
std::ofstream ofs(filename);
if (!ofs.is_open()) {
MS_LOG(ERROR) << "Open file '" << filename << "' failed!";
return;
}
ofs << GetFuncGraphProtoString(func_graph);
ofs.close();
// set file mode to read only by user
ChangeFileMode(filename, S_IRUSR);
MS_LOG(INFO) << "End saving graph to file in protobuf format";
#endif
}
bool ExecutorPy::ChangeExportGeirUseVmFlag(bool use_vm, const std::string &phase_s) const {
std::string phase_prefix = GetPhasePrefix(phase_s);
......@@ -365,6 +389,8 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons
info_[phase_s] = executor_info;
pip->Run();
// save compile graph to file in protobuf format
SaveCompiledGraphToPb(phase_s);
// save the run graph func to MsPipeLine
SaveCompiledGraph(phase_s);
......@@ -557,20 +583,6 @@ void Pipeline::Run() {
std::string user_graph_file = GetFilePathName("ModelDigraph.dot");
MS_LOG(DEBUG) << "Save user graph to: " << user_graph_file;
draw::DrawUserFuncGraph(user_graph_file, user_graph);
#ifdef ENABLE_DUMP_IR
std::string filename = GetFilePathName("ms_output.pb");
ChangeFileMode(filename, S_IRWXU);
std::ofstream ofs(filename);
if (!ofs.is_open()) {
MS_LOG(ERROR) << "Open file '" << filename << "' failed!";
return;
}
ofs << GetFuncGraphProtoString(user_graph);
ofs.close();
// set file mode to read only by user
ChangeFileMode(filename, S_IRUSR);
#endif
}
MS_LOG(INFO) << "End";
}
......
......@@ -70,6 +70,7 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> {
~ExecutorPy();
void SaveCompiledGraph(const std::string &phase_s);
void SaveCompiledGraphToPb(const std::string &phase_s);
bool CompileInner(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm);
bool Compile(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm);
......
......@@ -158,7 +158,7 @@ void Profile::Print(void) {
std::ostringstream oss;
PrintProfile(oss, *ctx_ptr_->time_info_);
std::string text = oss.str();
// the length of text is too long to use MS_LOGINFO, use printf to print it
// here use printf to output profile info, not use MS_LOG(INFO) since when open log, it affects performace
(void)printf("%s", text.c_str());
(void)fflush(stdout);
}
......@@ -358,7 +358,7 @@ void MsProfile::Print() {
PrintTimeStat(oss, groups[i], prefix);
}
std::string text = oss.str();
// the length of text is too long to use MS_LOGINFO, use printf to print it
// here use printf to output profile info, not use MS_LOG(INFO) since when open log, it affects performace
(void)printf("\nTime group info:\n%s", text.c_str());
(void)fflush(stdout);
}
......
......@@ -328,7 +328,7 @@ class _Executor:
raise TypeError('Parameters need OrderedDict type, but got {}'.
format(type(params)))
def compile(self, obj, *args, phase='predict', params=None):
def compile(self, obj, *args, phase='predict', params=None, do_convert=True):
"""
Compiles graph.
......@@ -337,6 +337,7 @@ class _Executor:
args (tuple): Function or cell input arguments.
phase (str): The name of compile phase. Default: 'predict'.
params (OrderedDict): The parameters dictionary used for init data graph. Default: None.
do_convert (bool): When set to True, convert ME graph to GE graph after compiling graph.
Return:
Str, the full phase of the cell.
......@@ -368,7 +369,8 @@ class _Executor:
if graph is None:
logger.error("%r graph compile failed.", phase)
if not do_convert:
return phase, True
if not enable_debug_runtime or enable_ge:
if _get_parallel_mode() in ["auto_parallel", "semi_auto_parallel"]:
obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
......
......@@ -450,7 +450,7 @@ def export(net, *inputs, file_name, file_format='GEIR'):
_executor.export(net, file_name, file_format)
elif file_format == 'ONNX': # file_format is 'ONNX'
phase_name = 'export_onnx'
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name)
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
onnx_stream = _executor._get_func_graph_proto(graph_id)
with open(file_name, 'wb') as f:
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册