Commit 813be72c authored by Megvii Engine Team

feat(mge/utils): refactor GraphInference and add more options

GitOrigin-RevId: 44b96dbf3dbad8abad7900b2b4e0e82bf1c8314d
Parent: d970b85d
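The refactored `GraphInference` drops the hand-rolled input/output rewriting from `__init__` in favor of the helper functions added in this commit, and `run()` now documents its arguments and returns `{output_name: output_value}`. A minimal usage sketch (the model path, input shape, and output name "pred" are hypothetical examples, and the import path assumes the class stays in `megengine.utils.comp_graph_tools`):

    import numpy as np
    from megengine.utils.comp_graph_tools import GraphInference

    # compile only the subgraph that ends at the var named "pred"
    net = GraphInference("model.mge", outputs=["pred"])

    # positional arrays fill the graph inputs in order; named arrays
    # can be passed as inp_dict={"data": arr} instead
    out = net.run(np.random.rand(1, 3, 224, 224).astype(np.float32))
    print(out["pred"].shape)  # run() returns a dict of numpy arrays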
@@ -9,10 +9,9 @@
 import collections
 import json
 import os
-import threading
 import weakref
-from concurrent.futures import Future, ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor
-from typing import Dict, List, Union
+from typing import Dict, List, Tuple, Union

 import numpy as np
@@ -22,7 +21,7 @@ from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode
 from .._imperative_rt.ops import BackwardGraph
 from .._wrap import device as as_device
 from ..ops.builtin import OpDef
-from .core import OpBase, TensorBase
+from .core import TensorBase


 def set_priority_to_id(dest_vars):
@@ -284,9 +283,9 @@ def optimize_for_inference(dest_vars, **kwargs):
     if kwargs:
         raise ValueError("unknown options: %s" % list(kwargs))

-    dest_vars = [var._node for var in dest_vars]
+    dest_vars = _unwrap(dest_vars)
     res_vars = _imperative_rt.optimize_for_inference(dest_vars, inference_options)
-    return [VarNode(i) for i in res_vars]
+    return _wrap(res_vars)


 CompGraphDumpResult = collections.namedtuple(
@@ -312,7 +311,7 @@ def dump_graph(
     keep_opr_priority: bool = False,
     strip_info_file=None,
     append_json=False
-):
+) -> Tuple[bytes, CompGraphDumpResult]:
     """
     serialize the computing graph of `output_vars` and get byte result.
@@ -347,22 +346,20 @@ def dump_graph(
     * ``params`` list of names of dumped params
     * ``outputs`` names of output vars
     """
-    ov = []
     if isinstance(output_vars, dict):
         used_vars = set()
         for name, var in output_vars.items():
-            assert isinstance(var, VarNode), "bad output var: {!r}".format(var)
             assert var.id not in used_vars, (
                 "var name is associated with a var object, so we can not have "
                 "two names given to the same var: {}".format(var)
             )
            used_vars.add(var.id)
             var.name = name
-            ov.append(var._node)
+        output_vars = list(output_vars.values())
     else:
-        for var in output_vars:
-            assert isinstance(var, VarNode), "bad output var: {!r}".format(var)
-            ov.append(var._node)
+        output_vars = list(output_vars)
+
+    ov = _unwrap(output_vars)

     stat = []
     inputs = []
@@ -413,7 +410,7 @@ CompGraphLoadResult = collections.namedtuple(
 )


-def load_graph(fpath):
+def load_graph(fpath) -> CompGraphLoadResult:
     """
     Load a serialized computing graph from file.
@@ -471,8 +468,7 @@ def apply_backward_varnode(op: BackwardGraph, *args: VarNode):
         graph._make_const_for_backward,
         args,
     )
-    outputs = [o._node if hasattr(o, "_node") else o for o in outputs]
-    return outputs
+    return _unwrap(outputs)


 set_cpp_apply_backward_varnode(apply_backward_varnode)
...
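The annotations added above make the serialization round trip explicit: `dump_graph` now advertises `Tuple[bytes, CompGraphDumpResult]`, and `load_graph` returns a `CompGraphLoadResult` that unpacks as `graph, _, output_nodes` (the same unpacking `GraphInference.__init__` uses below). A sketch under those assumptions, where `dest_vars` is presumed to be a list of output `VarNode`s of an existing graph:

    from megengine.core.tensor import megbrain_graph as G

    # dump: serialized bytes plus metadata (names of dumped params, outputs, ...)
    content, dump_info = G.dump_graph(dest_vars)
    with open("model.mge", "wb") as f:
        f.write(content)

    # load: unpack the namedtuple; the middle field is unused here
    graph, _, output_nodes = G.load_graph("model.mge")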
@@ -7,12 +7,14 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import collections
 from collections import OrderedDict
-from typing import Dict, List, Optional
+from typing import Dict, List, Tuple, Union

-import numpy
+import numpy as np

 from ..core import _imperative_rt
-from ..core._imperative_rt import OperatorNode, VarNode
+from ..core._imperative_rt import GraphProfiler
+from ..core._imperative_rt import OperatorNode as _OpNode
+from ..core._imperative_rt import VarNode as _VarNode
 from ..core.tensor import megbrain_graph as G
 from ..core.tensor.megbrain_graph import set_priority_to_id
 from ..tensor import Tensor
@@ -31,7 +33,9 @@ __all__ = [
 ]


-def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]:
+def get_dep_vars(
+    var: Union[_VarNode, List[_VarNode]], var_type: Union[str, List[str]] = None
+) -> List[_VarNode]:
     """
     Returns :class:`.tensor.core.megbrain_graph.VarNode` of type ``var_type`` that input ``var``
     depends on. If ``var_type`` is None, returns all types.
@@ -39,7 +43,7 @@ def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]:
     outputs = []
     memo = set()

-    if isinstance(var, VarNode):
+    if isinstance(var, _VarNode):
         var = [var]

     if isinstance(var_type, str):
@@ -61,14 +65,14 @@ def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]:
     return outputs


-def get_owner_opr_inputs(var: VarNode) -> List[VarNode]:
+def get_owner_opr_inputs(var: _VarNode) -> List[_VarNode]:
     """
     Gets the inputs of owner opr of a variable.
     """
     return var.owner.inputs


-def get_owner_opr_type(var: VarNode) -> str:
+def get_owner_opr_type(var: _VarNode) -> str:
     """
     Gets the type of owner opr of a variable.
@@ -76,15 +80,15 @@ def get_owner_opr_type(var: VarNode) -> str:
     return var.owner.type


-def get_opr_type(opr: OperatorNode) -> str:
+def get_opr_type(opr: _OpNode) -> str:
     """
     Gets the type of an opr.
     """
-    assert isinstance(opr, OperatorNode)
+    assert isinstance(opr, _OpNode)
     return opr.type


-def graph_traversal(outputs: VarNode):
+def graph_traversal(outputs: _VarNode):
     """
     Helper function to traverse the computing graph and return enough useful information.
@@ -142,8 +146,8 @@ def graph_traversal(outputs: VarNode):
 def get_oprs_seq(
-    outputs: List[VarNode], prune_reshape=False, prune_immtensor=True
-) -> List[OperatorNode]:
+    outputs: List[_VarNode], prune_reshape=False, prune_immtensor=True
+) -> List[_OpNode]:
     """
     Gets oprs in some topological order for a dumped model.
@@ -218,7 +222,9 @@ def get_oprs_seq(
     return oprs_seq


-def replace_vars(dst: VarNode, varmap: Dict[VarNode, VarNode]) -> List[VarNode]:
+def replace_vars(
+    dst: List[_VarNode], varmap: Dict[_VarNode, _VarNode]
+) -> List[_VarNode]:
     """
     Replaces vars in the graph.
@@ -232,21 +238,19 @@ def replace_vars(dst: VarNode, varmap: Dict[VarNode, VarNode]) -> List[VarNode]:
     repl_src_vec = []
     repl_dst_vec = []
     for i in dst:
-        assert isinstance(i, VarNode)
+        assert isinstance(i, _VarNode)
         dst_vec.append(i)

     for i, j in getattr(varmap, "items", lambda: varmap)():
-        assert isinstance(i, VarNode)
-        assert isinstance(j, VarNode)
+        assert isinstance(i, _VarNode)
+        assert isinstance(j, _VarNode)
         repl_src_vec.append(i)
         repl_dst_vec.append(j)

     return _imperative_rt.graph._replace_vars(repl_src_vec, repl_dst_vec, dst_vec)


-def replace_oprs(
-    dst: List[VarNode], oprmap: Dict[OperatorNode, OperatorNode]
-) -> List[VarNode]:
+def replace_oprs(dst: List[_VarNode], oprmap: Dict[_OpNode, _OpNode]) -> List[_VarNode]:
     """
     Replaces operators in the graph.
@@ -260,65 +264,154 @@ def replace_oprs(
     repl_src_vec = []
     repl_dst_vec = []
     for i in dst:
-        assert isinstance(i, VarNode)
+        assert isinstance(i, _VarNode)
         dst_vec.append(i)

     for i, j in getattr(oprmap, "items", lambda: oprmap)():
-        assert isinstance(i, OperatorNode)
-        assert isinstance(j, OperatorNode)
+        assert isinstance(i, _OpNode)
+        assert isinstance(j, _OpNode)
         repl_src_vec.append(i)
         repl_dst_vec.append(j)

     return _imperative_rt.graph._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec)
+
+
+def find_vars_by_name(dst: List[_VarNode], names: List[str]) -> List[_VarNode]:
+    """
+    Gets VarNode list by names in the graph.
+
+    :param dst: target vars representing the graph.
+    :param names: name list for target VarNode.
+    :return: results found by names.
+    """
+    output_names = names.copy()
+    all_vars = get_dep_vars(dst) + dst
+    # use dict to keep outputs order the same as names.
+    output_dict = {}
+    for i in all_vars:
+        if i.name in output_names:
+            output_dict[i.name] = i
+            output_names.remove(i.name)
+    assert len(output_names) == 0, "Can not find varnode {} in this model".format(
+        output_names
+    )
+    return [output_dict[i] for i in names]
+
+
+def convert_inputs(
+    dst: List[_VarNode], inputs: List[_VarNode] = None
+) -> Tuple[List[_VarNode], Dict[str, _VarNode]]:
+    """
+    Replaces ``Host2DeviceCopy`` with :class:`~.InputNode` in the graph,
+    so values can be fed via :meth:`~.InputNode.set_value` at run time.
+
+    :param dst: target vars representing the graph.
+    :param inputs: indicates which inputs are to be replaced. All
+        inputs (``Host2DeviceCopy``) will be replaced if not specified.
+    :return: new vars that correspond to ``dst`` with all inputs
+        replaced, and new inputs dict.
+    """
+    if inputs is None:
+        inputs = get_dep_vars(dst, "Host2DeviceCopy")
+    input_dict = OrderedDict()
+    replace_dict = {}
+    for inp in inputs:
+        inp_node = G.InputNode(
+            device=inp.comp_node, dtype=inp.dtype, shape=inp.shape, graph=inp.graph,
+        )
+        inp_node.name = inp.name
+        input_dict[inp.name] = inp_node
+        replace_dict[inp] = inp_node.outputs[0]
+
+    new_output_nodes = replace_vars(dst, replace_dict)
+    for old, new in zip(dst, new_output_nodes):
+        new.name = old.name
+    return new_output_nodes, input_dict
+
+
+def convert_outputs(dst: List[_VarNode]) -> Tuple[List[_VarNode], Dict[str, _VarNode]]:
+    """
+    Wraps ``dst`` with :class:`~.OutputNode` in the graph to get outputs
+    with :meth:`~.OutputNode.get_value`.
+
+    :param dst: target vars representing the graph.
+    :return: new vars that correspond to ``dst`` with all outputs
+        wrapped, and outputs dict.
+    """
+    output_dict = OrderedDict([(i.name, G.OutputNode(i)) for i in dst])
+    new_output_nodes = [i.outputs[0] for i in output_dict.values()]
+    return new_output_nodes, output_dict
+
+
+def embed_inputs(
+    dst: List[_VarNode], data: List[np.ndarray], inputs: List[_VarNode] = None
+) -> Tuple[List[_VarNode], Dict[str, _VarNode]]:
+    """
+    Embeds ``data`` into the graph's inputs of ``dst``.
+
+    :param dst: target vars representing the graph.
+    :param data: data to be embedded.
+    :param inputs: indicates which inputs are to be replaced. All
+        inputs (``Host2DeviceCopy``) will be replaced if not specified.
+    :return: new vars that correspond to ``dst`` with all inputs
+        replaced, and new inputs dict.
+    """
+    if inputs is None:
+        inputs = get_dep_vars(dst, "Host2DeviceCopy")
+    assert len(data) == len(inputs)
+    input_dict = OrderedDict()
+    replace_dict = {}
+    for inp, d in zip(inputs, data):
+        new_inp = _imperative_rt.make_shared(inp.graph, Tensor(d)._dev_tensor())
+        new_inp.name = inp.name
+        input_dict[inp.name] = new_inp
+        replace_dict[inp] = new_inp
+
+    new_output_nodes = replace_vars(dst, replace_dict)
+    for old, new in zip(dst, new_output_nodes):
+        new.name = old.name
+    return new_output_nodes, input_dict
+
+
 class GraphInference:
     """
-    Loads a serialized computing graph as a GraphInference object which can be used to execute the computing graph.
-    The `GraphInference.run()` accepts a list `inp_args` or a dict `inp_dict` {input_name: input_value} as input and returns a dict {output_name: output_value}.
+    Loads a serialized computing graph as a GraphInference object which can be used
+    to execute the computing graph.

     :param file: could be file object or filename.
     :param outputs: only compile the subgraph with outputs as its endpoints.
     """

-    def __init__(self, file, outputs: Optional[List[str]] = None):
-        *_, output_nodes = G.load_graph(file)
+    def __init__(
+        self,
+        file,
+        outputs: List[str] = None,
+        profiling: bool = False,
+        optimize_for_inference: bool = False,
+        **kwargs
+    ):
+        self._graph, _, output_nodes = G.load_graph(file)
         if outputs is not None:
-            output_name = outputs.copy()
-            all_vars = get_dep_vars(output_nodes) + output_nodes
-            new_outputs = {}
-            for i in all_vars:
-                if i.name in output_name:
-                    new_outputs[i.name] = i
-                    output_name.remove(i.name)
-            assert (
-                len(output_name) == 0
-            ), "Can not find varnode {} in this model".format(output_name)
-            output_nodes = [new_outputs[i] for i in outputs]
-        inputs = get_dep_vars(output_nodes, "Host2DeviceCopy")
-        self._inp_dict = OrderedDict()
-        replace_dict = {}
-        for idx, i in enumerate(inputs):
-            inp_node = G.InputNode(
-                device="xpux", dtype=inputs[idx].dtype, graph=inputs[0].graph
-            )
-            self._inp_dict[i.name] = inp_node
-            replace_dict[i] = inp_node.outputs[0]
-        new_output_nodes = replace_vars(output_nodes, replace_dict)
-        for old, new in zip(output_nodes, new_output_nodes):
-            new.name = old.name
-        self._out_dict = OrderedDict(
-            [(i.name, G.OutputNode(i)) for i in new_output_nodes]
-        )
-        new_out_list = [i.outputs[0] for i in self._out_dict.values()]
-        cg = new_out_list[0].graph
-        self._func = cg.compile(new_out_list)
+            output_nodes = find_vars_by_name(output_nodes, outputs)
+        self._origin_outputs = output_nodes
+
+        # replace inputs with `InputNode`
+        output_nodes, self._inp_dict = convert_inputs(output_nodes)
+
+        # replace outputs with `OutputNode`
+        output_nodes, self._oup_dict = convert_outputs(output_nodes)
+
+        self._func = self._graph.compile(output_nodes)

     def run(
-        self,
-        *inp_args: numpy.ndarray,
-        inp_dict: Optional[Dict[str, numpy.ndarray]] = None
-    ):
+        self, *inp_args: np.ndarray, inp_dict: Dict[str, np.ndarray] = None
+    ) -> Dict[str, np.ndarray]:
+        """
+        :param inp_args: list of input data.
+        :param inp_dict: dict of named input data.
+        :return: a dict {output_name: output_value}.
+        """
         assert len(inp_args) <= len(
             self._inp_dict
         ), "This model expects {} inputs".format(len(self._inp_dict))
@@ -335,8 +428,11 @@ class GraphInference:
         )
         for key in self._inp_dict:
             self._inp_dict[key].set_value(Tensor(inputs[key])._dev_tensor())

         self._func.execute()
+        self._func.wait()

         result = OrderedDict()
-        for key in self._out_dict:
-            result[key] = self._out_dict[key].get_value().numpy()
+        for key in self._oup_dict:
+            result[key] = self._oup_dict[key].get_value().numpy()
         return result
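With these changes, `GraphInference.__init__` is a thin composition of `load_graph`, `find_vars_by_name`, `convert_inputs`, `convert_outputs`, and `graph.compile`. The equivalent manual pipeline, useful when the intermediate nodes need to be kept around, looks roughly like this (the model path and the names "data"/"pred" are hypothetical, and the helper import path is assumed):

    import numpy as np
    from megengine import Tensor
    from megengine.core.tensor import megbrain_graph as G
    from megengine.utils.comp_graph_tools import (
        convert_inputs,
        convert_outputs,
        find_vars_by_name,
    )

    graph, _, output_nodes = G.load_graph("model.mge")
    output_nodes = find_vars_by_name(output_nodes, ["pred"])

    # Host2DeviceCopy vars become InputNodes that accept values at run time
    output_nodes, inp_dict = convert_inputs(output_nodes)
    # endpoint vars become OutputNodes that expose values after execution
    output_nodes, oup_dict = convert_outputs(output_nodes)
    func = graph.compile(output_nodes)

    arr = np.ones((1, 3, 224, 224), dtype="float32")
    inp_dict["data"].set_value(Tensor(arr)._dev_tensor())
    func.execute()
    func.wait()
    result = {name: node.get_value().numpy() for name, node in oup_dict.items()}

`embed_inputs` plays the same role as `convert_inputs` when the input data should be baked into the graph as constant device tensors instead of fed through `InputNode`s.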