diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py
index afd916113b41fa48a6e2cae854c2bf2c2b0ae3c6..9bd3aae5d67389e38adc1c976fb026729dbad3da 100644
--- a/imperative/python/megengine/core/tensor/megbrain_graph.py
+++ b/imperative/python/megengine/core/tensor/megbrain_graph.py
@@ -9,10 +9,9 @@
 import collections
 import json
 import os
-import threading
 import weakref
-from concurrent.futures import Future, ThreadPoolExecutor
-from typing import Dict, List, Union
+from concurrent.futures import ThreadPoolExecutor
+from typing import Dict, List, Tuple, Union
 
 import numpy as np
 
@@ -22,7 +21,7 @@ from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode
 from .._imperative_rt.ops import BackwardGraph
 from .._wrap import device as as_device
 from ..ops.builtin import OpDef
-from .core import OpBase, TensorBase
+from .core import TensorBase
 
 
 def set_priority_to_id(dest_vars):
@@ -284,9 +283,9 @@ def optimize_for_inference(dest_vars, **kwargs):
     if kwargs:
         raise ValueError("unknown options: %s" % list(kwargs))
 
-    dest_vars = [var._node for var in dest_vars]
+    dest_vars = _unwrap(dest_vars)
     res_vars = _imperative_rt.optimize_for_inference(dest_vars, inference_options)
-    return [VarNode(i) for i in res_vars]
+    return _wrap(res_vars)
 
 
 CompGraphDumpResult = collections.namedtuple(
@@ -312,7 +311,7 @@ def dump_graph(
     keep_opr_priority: bool = False,
     strip_info_file=None,
     append_json=False
-):
+) -> Tuple[bytes, CompGraphDumpResult]:
     """
     serialize the computing graph of `output_vars` and get byte result.
 
@@ -347,22 +346,20 @@ def dump_graph(
     * ``params`` list of names of dumped params
     * ``outputs`` names of output vars
     """
-    ov = []
     if isinstance(output_vars, dict):
         used_vars = set()
         for name, var in output_vars.items():
-            assert isinstance(var, VarNode), "bad output var: {!r}".format(var)
             assert var.id not in used_vars, (
                 "var name is associated with a var object, so we can not have "
                 "two names given to the same var: {}".format(var)
             )
             used_vars.add(var.id)
             var.name = name
-            ov.append(var._node)
+        output_vars = list(output_vars.values())
     else:
-        for var in output_vars:
-            assert isinstance(var, VarNode), "bad output var: {!r}".format(var)
-            ov.append(var._node)
+        output_vars = list(output_vars)
+
+    ov = _unwrap(output_vars)
 
     stat = []
     inputs = []
@@ -413,7 +410,7 @@ CompGraphLoadResult = collections.namedtuple(
 )
 
 
-def load_graph(fpath):
+def load_graph(fpath) -> CompGraphLoadResult:
     """
     Load a serialized computing graph from file.
 
@@ -471,8 +468,7 @@ def apply_backward_varnode(op: BackwardGraph, *args: VarNode):
         graph._make_const_for_backward,
         args,
     )
-    outputs = [o._node if hasattr(o, "_node") else o for o in outputs]
-    return outputs
+    return _unwrap(outputs)
 
 
 set_cpp_apply_backward_varnode(apply_backward_varnode)
diff --git a/imperative/python/megengine/utils/comp_graph_tools.py b/imperative/python/megengine/utils/comp_graph_tools.py
index 80ff11138ba6a51a81d4b8c4b4f3703083b5e368..c7262e0a2a7ca907ba00692eaf149916d2fd4629 100644
--- a/imperative/python/megengine/utils/comp_graph_tools.py
+++ b/imperative/python/megengine/utils/comp_graph_tools.py
@@ -7,12 +7,14 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import collections
 from collections import OrderedDict
-from typing import Dict, List, Optional
+from typing import Dict, List, Tuple, Union
 
-import numpy
+import numpy as np
 
 from ..core import _imperative_rt
-from ..core._imperative_rt import OperatorNode, VarNode
+from ..core._imperative_rt import GraphProfiler
+from ..core._imperative_rt import OperatorNode as _OpNode
+from ..core._imperative_rt import VarNode as _VarNode
 from ..core.tensor import megbrain_graph as G
 from ..core.tensor.megbrain_graph import set_priority_to_id
 from ..tensor import Tensor
@@ -31,7 +33,9 @@ __all__ = [
 ]
 
 
-def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]:
+def get_dep_vars(
+    var: Union[_VarNode, List[_VarNode]], var_type: Union[str, List[str]] = None
+) -> List[_VarNode]:
     """
     Returns :class:`.tensor.core.megbrain_graph.VarNode` of type ``var_type``
     that input ``var`` depends on. If ``var_type`` is None, returns all types.
@@ -39,7 +43,7 @@ def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]:
     outputs = []
     memo = set()
 
-    if isinstance(var, VarNode):
+    if isinstance(var, _VarNode):
         var = [var]
 
     if isinstance(var_type, str):
@@ -61,14 +65,14 @@ def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]:
     return outputs
 
 
-def get_owner_opr_inputs(var: VarNode) -> List[VarNode]:
+def get_owner_opr_inputs(var: _VarNode) -> List[_VarNode]:
     """
     Gets the inputs of owner opr of a variable.
     """
     return var.owner.inputs
 
 
-def get_owner_opr_type(var: VarNode) -> str:
+def get_owner_opr_type(var: _VarNode) -> str:
     """
     Gets the type of owner opr of a variable.
 
@@ -76,15 +80,15 @@ def get_owner_opr_type(var: VarNode) -> str:
     return var.owner.type
 
 
-def get_opr_type(opr: OperatorNode) -> str:
+def get_opr_type(opr: _OpNode) -> str:
     """
     Gets the type of an opr.
     """
-    assert isinstance(opr, OperatorNode)
+    assert isinstance(opr, _OpNode)
     return opr.type
 
 
-def graph_traversal(outputs: VarNode):
+def graph_traversal(outputs: _VarNode):
     """
     Helper function to traverse the computing graph and return enough
     useful information.
@@ -142,8 +146,8 @@ def graph_traversal(outputs: VarNode):
 
 
 def get_oprs_seq(
-    outputs: List[VarNode], prune_reshape=False, prune_immtensor=True
-) -> List[OperatorNode]:
+    outputs: List[_VarNode], prune_reshape=False, prune_immtensor=True
+) -> List[_OpNode]:
     """
     Gets oprs in some topological order for a dumped model.
 
@@ -218,7 +222,9 @@ def get_oprs_seq(
     return oprs_seq
 
 
-def replace_vars(dst: VarNode, varmap: Dict[VarNode, VarNode]) -> List[VarNode]:
+def replace_vars(
+    dst: List[_VarNode], varmap: Dict[_VarNode, _VarNode]
+) -> List[_VarNode]:
     """
     Replaces vars in the graph.
 
@@ -232,21 +238,19 @@ def replace_vars(dst: VarNode, varmap: Dict[VarNode, VarNode]) -> List[VarNode]:
     repl_src_vec = []
     repl_dst_vec = []
     for i in dst:
-        assert isinstance(i, VarNode)
+        assert isinstance(i, _VarNode)
         dst_vec.append(i)
     for i, j in getattr(varmap, "items", lambda: varmap)():
-        assert isinstance(i, VarNode)
-        assert isinstance(j, VarNode)
+        assert isinstance(i, _VarNode)
+        assert isinstance(j, _VarNode)
         repl_src_vec.append(i)
         repl_dst_vec.append(j)
     return _imperative_rt.graph._replace_vars(repl_src_vec, repl_dst_vec, dst_vec)
 
 
-def replace_oprs(
-    dst: List[VarNode], oprmap: Dict[OperatorNode, OperatorNode]
-) -> List[VarNode]:
+def replace_oprs(dst: List[_VarNode], oprmap: Dict[_OpNode, _OpNode]) -> List[_VarNode]:
     """
     Replaces operators in the graph.
 
@@ -260,65 +264,154 @@ def replace_oprs(
     repl_src_vec = []
     repl_dst_vec = []
     for i in dst:
-        assert isinstance(i, VarNode)
+        assert isinstance(i, _VarNode)
         dst_vec.append(i)
     for i, j in getattr(oprmap, "items", lambda: oprmap)():
-        assert isinstance(i, OperatorNode)
-        assert isinstance(j, OperatorNode)
+        assert isinstance(i, _OpNode)
+        assert isinstance(j, _OpNode)
         repl_src_vec.append(i)
         repl_dst_vec.append(j)
     return _imperative_rt.graph._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec)
 
 
+def find_vars_by_name(dst: List[_VarNode], names: List[str]) -> List[_VarNode]:
+    """
+    Gets a list of VarNodes matching ``names`` in the graph.
+
+    :param dst: target vars representing the graph.
+    :param names: name list for target VarNode.
+
+    :return: results found by names.
+    """
+    output_names = names.copy()
+    all_vars = get_dep_vars(dst) + dst
+    # use dict to keep outputs order the same as names.
+    output_dict = {}
+    for i in all_vars:
+        if i.name in output_names:
+            output_dict[i.name] = i
+            output_names.remove(i.name)
+    assert len(output_names) == 0, "Can not find varnode {} in this model".format(
+        output_names
+    )
+    return [output_dict[i] for i in names]
+
+
+def convert_inputs(
+    dst: List[_VarNode], inputs: List[_VarNode] = None
+) -> Tuple[List[_VarNode], Dict[str, _VarNode]]:
+    """
+    Replaces ``Host2DeviceCopy`` oprs in the graph with :class:`~.InputNode`,
+    so that input values can be fed with :meth:`~.InputNode.set_value` before running.
+
+    :param dst: target vars representing the graph.
+    :param inputs: indicates which inputs to be replaced. All
+        inputs (``Host2DeviceCopy``) will be replaced if not specified.
+
+    :return: new vars that correspond to ``dst`` with all inputs
+        replaced, and new inputs dict.
+    """
+    if inputs is None:
+        inputs = get_dep_vars(dst, "Host2DeviceCopy")
+    input_dict = OrderedDict()
+    replace_dict = {}
+    for inp in inputs:
+        inp_node = G.InputNode(
+            device=inp.comp_node, dtype=inp.dtype, shape=inp.shape, graph=inp.graph,
+        )
+        inp_node.name = inp.name
+        input_dict[inp.name] = inp_node
+        replace_dict[inp] = inp_node.outputs[0]
+    new_output_nodes = replace_vars(dst, replace_dict)
+    for old, new in zip(dst, new_output_nodes):
+        new.name = old.name
+    return new_output_nodes, input_dict
+
+
+def convert_outputs(dst: List[_VarNode]) -> Tuple[List[_VarNode], Dict[str, _VarNode]]:
+    """
+    Wraps ``dst`` with :class:`~.OutputNode` in the graph, so that output values
+    can be fetched with :meth:`~.OutputNode.get_value` after running.
+
+    :param dst: target vars representing the graph.
+
+    :return: new vars that correspond to ``dst`` wrapped with
+        :class:`~.OutputNode`, and the outputs dict.
+    """
+    output_dict = OrderedDict([(i.name, G.OutputNode(i)) for i in dst])
+    new_output_nodes = [i.outputs[0] for i in output_dict.values()]
+    return new_output_nodes, output_dict
+
+
+def embed_inputs(
+    dst: List[_VarNode], data: List[np.ndarray], inputs: List[_VarNode] = None
+) -> Tuple[List[_VarNode], Dict[str, _VarNode]]:
+    """
+    Embeds ``data`` into the graph's inputs of ``dst``.
+
+    :param dst: target vars representing the graph.
+    :param data: data to be embedded.
+    :param inputs: indicates which inputs to be replaced. All
+        inputs (``Host2DeviceCopy``) will be replaced if not specified.
+
+    :return: new vars that correspond to ``dst`` with all inputs
+        replaced, and new inputs dict.
+    """
+    if inputs is None:
+        inputs = get_dep_vars(dst, "Host2DeviceCopy")
+    assert len(data) == len(inputs)
+    input_dict = OrderedDict()
+    replace_dict = {}
+    for inp, d in zip(inputs, data):
+        new_inp = _imperative_rt.make_shared(inp.graph, Tensor(d)._dev_tensor())
+        new_inp.name = inp.name
+        input_dict[inp.name] = new_inp
+        replace_dict[inp] = new_inp
+    new_output_nodes = replace_vars(dst, replace_dict)
+    for old, new in zip(dst, new_output_nodes):
+        new.name = old.name
+    return new_output_nodes, input_dict
+
+
 class GraphInference:
     """
-    Loads a serialized computing graph as a GraphInference object which can be used to execute the computing graph.
-    The `GraphInference.run()` accepts a list `inp_args` or a dict `inp_dict` {input_name: input_value} as input and returns a dict {output_name: output_value}.
+    Loads a serialized computing graph as a GraphInference object which can be used
+    to execute the computing graph.
 
     :param file: could be file object or filename.
     :param outputs: only compile the subgraph with outputs as its endpoints.
     """
 
-    def __init__(self, file, outputs: Optional[List[str]] = None):
-        *_, output_nodes = G.load_graph(file)
+    def __init__(
+        self,
+        file,
+        outputs: List[str] = None,
+        profiling: bool = False,
+        optimize_for_inference: bool = False,
+        **kwargs
+    ):
+        self._graph, _, output_nodes = G.load_graph(file)
         if outputs is not None:
-            output_name = outputs.copy()
-            all_vars = get_dep_vars(output_nodes) + output_nodes
-            new_outputs = {}
-            for i in all_vars:
-                if i.name in output_name:
-                    new_outputs[i.name] = i
-                    output_name.remove(i.name)
-            assert (
-                len(output_name) == 0
-            ), "Can not find varnode {} in this model".format(output_name)
-            output_nodes = [new_outputs[i] for i in outputs]
-        inputs = get_dep_vars(output_nodes, "Host2DeviceCopy")
-        self._inp_dict = OrderedDict()
-        replace_dict = {}
-        for idx, i in enumerate(inputs):
-            inp_node = G.InputNode(
-                device="xpux", dtype=inputs[idx].dtype, graph=inputs[0].graph
-            )
-            self._inp_dict[i.name] = inp_node
-            replace_dict[i] = inp_node.outputs[0]
-        new_output_nodes = replace_vars(output_nodes, replace_dict)
-        for old, new in zip(output_nodes, new_output_nodes):
-            new.name = old.name
-        self._out_dict = OrderedDict(
-            [(i.name, G.OutputNode(i)) for i in new_output_nodes]
-        )
-        new_out_list = [i.outputs[0] for i in self._out_dict.values()]
-        cg = new_out_list[0].graph
-        self._func = cg.compile(new_out_list)
+            output_nodes = find_vars_by_name(output_nodes, outputs)
+        self._origin_outputs = output_nodes
+
+        # replace inputs with `InputNode`
+        output_nodes, self._inp_dict = convert_inputs(output_nodes)
+
+        # replace outputs with `OutputNode`
+        output_nodes, self._oup_dict = convert_outputs(output_nodes)
+
+        self._func = self._graph.compile(output_nodes)
 
     def run(
-        self,
-        *inp_args: numpy.ndarray,
-        inp_dict: Optional[Dict[str, numpy.ndarray]] = None
-    ):
+        self, *inp_args: np.ndarray, inp_dict: Dict[str, np.ndarray] = None
+    ) -> Dict[str, np.ndarray]:
+        """
+        Runs the compiled graph with the given inputs.
+
+        :param inp_args: positional input data.
+        :param inp_dict: dict of named input data.
+        :return: a dict {output_name: output_value}.
+        """
         assert len(inp_args) <= len(
             self._inp_dict
         ), "This model expects {} inputs".format(len(self._inp_dict))
@@ -335,8 +428,11 @@ class GraphInference:
         )
         for key in self._inp_dict:
             self._inp_dict[key].set_value(Tensor(inputs[key])._dev_tensor())
+
        self._func.execute()
+        self._func.wait()
+
         result = OrderedDict()
-        for key in self._out_dict:
-            result[key] = self._out_dict[key].get_value().numpy()
+        for key in self._oup_dict:
+            result[key] = self._oup_dict[key].get_value().numpy()
         return result
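---

For reviewers: the new helper functions compose the same way `GraphInference.__init__`/`run` use them. Below is a minimal sketch of that flow with the lower-level API; the file name `model.mge` and the input shape are assumptions, not part of this patch:

```python
import numpy as np

from megengine import Tensor
from megengine.core.tensor import megbrain_graph as G
from megengine.utils.comp_graph_tools import convert_inputs, convert_outputs

# "model.mge" is a placeholder for a graph dumped beforehand.
graph, _, output_nodes = G.load_graph("model.mge")

# Swap every Host2DeviceCopy var for an InputNode that can be fed at run time.
output_nodes, inp_dict = convert_inputs(output_nodes)

# Wrap the endpoints with OutputNode so their values can be read back.
output_nodes, oup_dict = convert_outputs(output_nodes)

func = graph.compile(output_nodes)

# Feed inputs, run, and wait for completion, mirroring GraphInference.run.
for name, inp_node in inp_dict.items():
    # The shape here is a placeholder; use the real input shape of your model.
    inp_node.set_value(Tensor(np.ones((1, 3, 224, 224), "float32"))._dev_tensor())
func.execute()
func.wait()

results = {name: node.get_value().numpy() for name, node in oup_dict.items()}
```

`embed_inputs` follows the same pattern but bakes constant data into the graph instead of creating feedable `InputNode`s.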
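And a usage sketch of the refactored `GraphInference` itself. The model path and the input name `"data"` are hypothetical; per this patch, `run()` accepts positional arrays matched against the model's `Host2DeviceCopy` vars, or named arrays via `inp_dict`, and returns an `OrderedDict` mapping output names to numpy arrays:

```python
import numpy as np

from megengine.utils.comp_graph_tools import GraphInference

# "model.mge" and the input name "data" are placeholders for a real dumped model.
net = GraphInference("model.mge")

x = np.random.randn(1, 3, 224, 224).astype("float32")

# Inputs can be passed positionally...
out = net.run(x)

# ...or by name through inp_dict.
out = net.run(inp_dict={"data": x})

# run() returns an OrderedDict of {output_name: numpy array}.
for name, value in out.items():
    print(name, value.shape)
```

Passing `outputs=["some_var_name"]` to the constructor compiles only the subgraph ending at those vars, resolved through the new `find_vars_by_name` helper.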