diff --git a/imperative/python/megengine/__init__.py b/imperative/python/megengine/__init__.py index 9a7cfc706d8b18768f086160ffdc93c3d5ef91ae..32aee56b0bde766d1f3504836ae35a6b22f8768e 100644 --- a/imperative/python/megengine/__init__.py +++ b/imperative/python/megengine/__init__.py @@ -78,7 +78,6 @@ from .logger import enable_debug_log, get_logger, set_log_file, set_log_level from .serialization import load, save from .tensor import Parameter, Tensor, tensor from .version import __version__ -from .utils import comp_graph_tools as cgtools _set_fork_exec_path_for_timed_func( sys.executable, diff --git a/imperative/python/megengine/utils/comp_graph_tools.py b/imperative/python/megengine/utils/comp_graph_tools.py index 2522e3f863045f0a4bdb58a5445895245bced93f..47cbb8fcff38395e8854c513ffb0e4232770c68b 100644 --- a/imperative/python/megengine/utils/comp_graph_tools.py +++ b/imperative/python/megengine/utils/comp_graph_tools.py @@ -15,6 +15,19 @@ from ..core._imperative_rt import OperatorNode, VarNode from ..core.tensor import megbrain_graph as G from ..core.tensor.raw_tensor import as_raw_tensor +__all__ = [ + "get_dep_vars", + "get_owner_opr_inputs", + "get_owner_opr_type", + "get_opr_type", + "graph_traversal", + "get_oprs_seq", + "replace_vars", + "replace_oprs", + "set_priority_to_id", + "load_and_inference", +] + def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]: """ @@ -166,7 +179,7 @@ def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNo # reshape op definition: reshape(input_tensor, dest_shape) -> output_tensor # when inferencing, shape of output_tensor is already known, so one can prune some operators related to dest_shape in the loaded graph def prune_reshape_oprs(outputs, oprs_seq, var2oprs): - def iterative_pruning(cur_opr, post_opr, marked_opr_ids): + def iterative_pruning(cur_opr, post_opr, marked_opr_ids, visited): useless = True for oup in cur_opr.outputs: if "workspace" not in oup.name: @@ -177,15 +190,20 @@ def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNo if useless: marked_opr_ids.append(cur_opr.id) - for inp in cur_opr.inputs: - iterative_pruning(inp.owner, cur_opr, marked_opr_ids) + for opr in set([var.owner for var in cur_opr.inputs]): + if (opr.id, cur_opr.id) not in visited: + visited.add((opr.id, cur_opr.id)) + iterative_pruning(opr, cur_opr, marked_opr_ids, visited) reshape_vars = get_dep_vars(outputs, "Reshape") reshape_oprs = [var.owner for var in reshape_vars] marked_opr_ids = [] + visited = set() for reshape_opr in reshape_oprs: - iterative_pruning(reshape_opr.inputs[1].owner, reshape_opr, marked_opr_ids) + iterative_pruning( + reshape_opr.inputs[1].owner, reshape_opr, marked_opr_ids, visited + ) # filter out all marked oprs return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq)) diff --git a/imperative/python/test/unit/test_cgtools.py b/imperative/python/test/unit/test_cgtools.py index 3f0f341e1f7f06d52e27939d093bbb75b626cc36..c90760b488f6664282db9c323bcaf8bd44592ec1 100644 --- a/imperative/python/test/unit/test_cgtools.py +++ b/imperative/python/test/unit/test_cgtools.py @@ -13,9 +13,10 @@ import pytest import megengine import megengine.functional as F import megengine.module as M -from megengine import cgtools +import megengine.utils.comp_graph_tools as cgtools from megengine.core.tensor import megbrain_graph as mgb_graph from megengine.core.tensor.raw_tensor import as_raw_tensor +from megengine.core.tensor.utils import astensor1d from megengine.jit import trace @@ -98,3 +99,38 @@ def test_load_refcnt(): graph, _, (varnode,) = mgb_graph.load_graph(io.BytesIO(buf)) del graph varnode.owner + + +def test_get_opr_seq(): + class Net(M.Module): + def __init__(self): + super().__init__() + self.data = megengine.tensor( + np.random.random((1, 1, 4, 4)), dtype=np.float32 + ) + + def forward(self, input): + A = input.shape[0] + shape = astensor1d((A, A), self.data, dtype="int32", device=input.device) + x = F.reshape(self.data, shape) + o = input + x + return o + + net = Net() + input = megengine.tensor(np.random.random((4, 4)), dtype=np.float32) + + @trace(symbolic=True, capture_as_const=True) + def func(inp, *, net=None): + return net(inp) + + func(input, net=net) + file = io.BytesIO() + func.dump(file, optimize_for_inference=False) + file.seek(0) + *_, outputs = mgb_graph.load_graph(file) + + seq_1 = cgtools.get_oprs_seq(outputs, True) + assert len(seq_1) == 5 + + seq_2 = cgtools.get_oprs_seq(outputs, False) + assert len(seq_2) == 6 diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py index 32ec58f52adb5eb72134ce0722ab3f2d83319b6c..e0ce660f50931dadf46968a4aac6beaf8a6c05e7 100644 --- a/imperative/python/test/unit/test_tracing.py +++ b/imperative/python/test/unit/test_tracing.py @@ -14,7 +14,8 @@ import pytest import megengine.core.tensor.megbrain_graph as G import megengine.functional as F -from megengine import cgtools, tensor +import megengine.utils.comp_graph_tools as cgtools +from megengine import tensor from megengine.core._trace_option import set_symbolic_shape from megengine.core.ops import builtin as ops from megengine.core.ops.builtin import Elemwise