提交 b1baee60 编写于 作者: M Megvii Engine Team

feat(imperative/utils): add optimize-for-inference interface for opgraph

GitOrigin-RevId: 9f93f821905dc05e3968247129920a0a1d43712f
上级 86598c82
......@@ -11,7 +11,7 @@ import fnmatch
import itertools
import re
from collections import OrderedDict
from typing import Dict, List
from typing import Dict, List, Sequence
import numpy as np
......@@ -87,42 +87,11 @@ class Network:
for o in opr.outputs:
self.all_vars_map[o.var.id] = o
def dump(
self,
file,
*,
keep_var_name: int = 1,
keep_opr_name: bool = False,
keep_param_name: bool = False,
keep_opr_priority: bool = False,
strip_info_file=None,
append_json=False,
optimize_for_inference=True,
append=False,
**kwargs
):
"""
Serializes graph to file.
def optimize_for_inference(self, dest_vars, **kwargs):
    r"""
    Applies the optimize_for_inference pass to the operator graph.

    :param dest_vars: a single output var or a sequence of output vars in
        the operator graph.
    :return: list of the transformed output vars corresponding to
        ``dest_vars``.

    :Keyword Arguments: forwarded unchanged to
        ``G.optimize_for_inference`` (e.g. fusion / layout-transform
        options supported by the inference-optimization pass).
    """
    # Accept a bare var as well as a sequence of vars.
    if not isinstance(dest_vars, Sequence):
        dest_vars = [dest_vars]
    # Unwrap to the raw graph-level VarNodes expected by the pass.
    raw_vars = [G.VarNode(var.var) for var in dest_vars]
    optimized = G.optimize_for_inference(raw_vars, **kwargs)
    # Map the optimized graph vars back to this Network's var wrappers.
    return [self._get_var(var) for var in optimized]
def dump(
self,
file,
*,
keep_var_name: int = 1,
keep_opr_name: bool = False,
keep_param_name: bool = False,
keep_opr_priority: bool = False,
strip_info_file=None,
append_json=False,
optimize_for_inference=True,
append=False,
**kwargs
):
"""
Serializes graph to file.
:param file: output file, could be file object or filename.
:param append: whether output is appended to ``file``.
Only works when ``file`` is str.
:param keep_var_name: level for keeping variable names:
* 0: none of the names are kept
* 1: (default)keep names of output vars
* 2: keep names of all (output and internal) vars
:param keep_opr_name: whether to keep operator names.
:param keep_param_name: whether to keep param names, so param values can be
easily manipulated after loading model
:param keep_opr_priority: whether to keep priority setting for operators
:param strip_info_file: a string for path or a file handler. if is not None,
then the dump information for code strip would be written to ``strip_info_file``
:param append_json: will be checked when `strip_info_file` is not None. If set
true, the information for code strip will be appended to strip_info_file;
if set false, strip_info_file will be rewritten.
:param optimize_for_inference: enable optimizations;
will skip all optimize options if this is False. Default: True
:Keyword Arguments:
See also :py:meth:`optimize_for_inference`.
"""
self._compile()
out = [G.VarNode(var.var) for var in self.output_vars]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册