diff --git a/python_module/megengine/_internal/comp_graph_tools.py b/python_module/megengine/_internal/comp_graph_tools.py
index 19223d072cff7b3a0deae57ee6a397bb22fa1272..5ef32bd869627ca954cddd3115e7956eec06c397 100644
--- a/python_module/megengine/_internal/comp_graph_tools.py
+++ b/python_module/megengine/_internal/comp_graph_tools.py
@@ -67,6 +67,145 @@ def get_type(var):
     return _mgb._get_owner_opr_type(var)
 
 
+def get_opr_type(opr):
+    """get the type of an opr
+
+    :type opr: :class:`.Operator`
+    :rtype: ``str``
+    """
+    assert isinstance(opr, _mgb.Operator)
+    return _mgb._get_opr_type(opr)
+
+
+def graph_traversal(outputs):
+    """helper function to traverse the computing graph and return useful metadata about its oprs and vars
+
+    :param outputs: model outputs
+    :type outputs: :class:`.SymbolVar`
+    :return: tuple (map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree)
+        WHERE
+        map_oprs is dict from opr_id to actual opr
+        map_vars is dict from var_id to actual var
+        var2oprs is dict from var_id to a list of (dest opr_id, input index) pairs
+        opr2receivers is dict from opr_id to the ids of oprs consuming its outputs
+        indegree2opr is dict from in-degree to the set of opr ids with that in-degree
+        opr2indegree is dict from opr_id to its in-degree in the computing graph
+
+        (indegree2opr, opr2indegree) are only used by the topological sort in get_oprs_seq
+    """
+    # meta information for comp graph
+    map_oprs = collections.defaultdict(set)
+    map_vars = collections.defaultdict(set)
+
+    var2oprs = collections.defaultdict(list)
+    opr2receivers = collections.defaultdict(list)
+
+    queue = list(map(lambda x: x.owner_opr, outputs))
+    visited = set(map(lambda x: x.id, queue))
+
+    # iterate through whole comp_graph, fill in meta information
+    indegree2opr = collections.defaultdict(set)
+    opr2indegree = {}
+
+    idx = 0
+    while idx < len(queue):
+        cur_opr = queue[idx]
+        map_oprs[cur_opr.id] = cur_opr
+
+        idx += 1
+
+        indegree = 0
+        for var_idx, var in enumerate(cur_opr.inputs):
+            map_vars[var.id] = var
+            var2oprs[var.id].append((cur_opr.id, var_idx))
+
+            pre_opr = var.owner_opr
+
+            if pre_opr.id not in visited:
+                visited.add(pre_opr.id)
+                queue.append(pre_opr)
+
+            indegree += 1
+            opr2receivers[pre_opr.id].append(cur_opr.id)
+
+        indegree2opr[indegree].add(cur_opr.id)
+        opr2indegree[cur_opr.id] = indegree
+
+    return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree
+
+
+def get_oprs_seq(outputs, prune_reshape=False):
+    """get oprs in topological order for a dumped model
+
+    :param outputs: model outputs
+    :param prune_reshape: whether to prune operators that are useless at inference time
+    :return: opr list in correct execution order
+    """
+
+    def topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree):
+        # generate an execution order with topological sort algorithm
+        oprs_seq = []
+        nr_remain = len(map_oprs)
+        while indegree2opr[0]:
+            opr_id = indegree2opr[0].pop()
+            opr = map_oprs[opr_id]
+            nr_remain -= 1
+
+            # skip const value generation operator
+            if get_opr_type(opr) != "ImmutableTensor":
+                oprs_seq.append(opr)
+
+            for post_id in opr2receivers[opr_id]:
+                indegree = opr2indegree[post_id]
+                indegree2opr[indegree].remove(post_id)
+
+                indegree -= 1
+                indegree2opr[indegree].add(post_id)
+                opr2indegree[post_id] = indegree
+
+        assert nr_remain == 0, "there are {} remaining nodes; cyclic graph?".format(
+            nr_remain
+        )
+        return oprs_seq
+
+    # reshape op definition: reshape(input_tensor, dest_shape) -> output_tensor
+    # at inference time, the shape of output_tensor is already known, so one can prune some operators related to dest_shape in the loaded graph
+    def prune_reshape_oprs(outputs, oprs_seq, var2oprs):
+        def iterative_pruning(cur_opr, post_opr, marked_opr_ids):
+            useless = True
+            for oup in cur_opr.outputs:
+                if "workspace" not in oup.name:
+                    var_idx = post_opr.inputs.index(oup)
+                    var2oprs[oup.id].remove((post_opr.id, var_idx))
+                    useless = useless and (len(var2oprs[oup.id]) == 0)
+
+            if useless:
+                marked_opr_ids.append(cur_opr.id)
+
+                for inp in cur_opr.inputs:
+                    iterative_pruning(inp.owner_opr, cur_opr, marked_opr_ids)
+
+        reshape_vars = get_dep_vars(outputs, "Reshape")
+        reshape_oprs = [var.owner_opr for var in reshape_vars]
+
+        marked_opr_ids = []
+        for reshape_opr in reshape_oprs:
+            iterative_pruning(
+                reshape_opr.inputs[1].owner_opr, reshape_opr, marked_opr_ids
+            )
+
+        # filter out all marked oprs
+        return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq))
+
+    map_oprs, _, var2oprs, opr2receivers, indegree2opr, opr2indegree = graph_traversal(
+        outputs
+    )
+    oprs_seq = topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree)
+    if prune_reshape is True:
+        oprs_seq = prune_reshape_oprs(outputs, oprs_seq, var2oprs.copy())
+    return oprs_seq
+
+
 def replace_vars(dst, varmap):
     """replace vars in the graph
 
diff --git a/python_module/src/swig/comp_graph_tools.i b/python_module/src/swig/comp_graph_tools.i
index 4ad601db7aaecd9ef17507bf6eff3cc250ca709f..26253fa667f347f38a7cbbb5485ba8d5db9779ad 100644
--- a/python_module/src/swig/comp_graph_tools.i
+++ b/python_module/src/swig/comp_graph_tools.i
@@ -10,6 +10,10 @@
         return var.node()->owner_opr()->dyn_typeinfo()->name;
     }
 
+    std::string _get_opr_type(Operator opr) {
+        return opr.node()->dyn_typeinfo()->name;
+    }
+
     SymbolVarArray _replace_vars(const SymbolVarArray& repl_src,
                                  const SymbolVarArray& repl_dst,
                                  const SymbolVarArray& vars) {
diff --git a/python_module/test/unit/jit/test_jit.py b/python_module/test/unit/jit/test_jit.py
index d2b6eecbf33160a9d39b90eb7e97fdb3f9799b60..b14b1dfc5eec79c7a489957eeff23a638f56c282 100644
--- a/python_module/test/unit/jit/test_jit.py
+++ b/python_module/test/unit/jit/test_jit.py
@@ -15,6 +15,7 @@ import pytest
 
 import megengine as mge
 import megengine._internal as mgb
+import megengine.functional as F
 import megengine.module as M
 from megengine import functional as F
 from megengine import jit, tensor
@@ -148,6 +149,49 @@ def test_dump_volatile():
     assert mgb.cgtools.get_type(mgb.cgtools.get_inputs(out)[1]) == "SharedDeviceTensor"
 
 
+def test_graph_traversal():
+    net = M.Conv2d(3, 4, 3, 1, 1, groups=1, bias=False)
+    net.eval()
+
+    @jit.trace(symbolic=True)
+    def fun(data):
+        return net(data)
+
+    data = np.random.random([1, 3, 224, 224]).astype(np.float32)
+    fun.trace(data)
+
+    with mkstemp() as out:
+        fun.dump(out)
+        *_, outputs = mgb.load_comp_graph_from_file(out)
+
+    _, map_vars, var2oprs, *_ = mgb.cgtools.graph_traversal(outputs)
+    input_var = map_vars[1]
+    _, var_idx = var2oprs[input_var.id][0]
+
+    assert var_idx == 0
+
+
+def test_network_visitor():
+    @jit.trace(symbolic=True)
+    def f(x):
+        # this line will produce shape_of, subtensor and concat ops;
+        # after pruning, they will be deleted
+        target_shape = (x.shape[0], -1)
+
+        return x.reshape(*target_shape)
+
+    f.trace(tensor(np.random.random([2, 3, 4, 5]).astype(np.float32)))
+
+    with mkstemp() as out:
+        f.dump(out)
+        *_, outputs = mgb.load_comp_graph_from_file(out)
+
+    all_oprs = mgb.cgtools.get_oprs_seq(outputs)
+    pruned_oprs = mgb.cgtools.get_oprs_seq(outputs, prune_reshape=True)
+
+    assert len(all_oprs) == len(pruned_oprs) + 3
+
+
 def test_shape_tracing():
     for symbolic in [False, True]:
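
Reviewer note: a minimal usage sketch of the helpers this patch adds, not part of the diff itself. It assumes a model was already dumped to a hypothetical path "model.mge" via jit.trace(...).dump(), mirroring the tests above; all calls (load_comp_graph_from_file, graph_traversal, get_oprs_seq, get_opr_type) come from the patch.

# usage sketch -- "model.mge" is a hypothetical path, not taken from the patch
import megengine._internal as mgb

*_, outputs = mgb.load_comp_graph_from_file("model.mge")

# full traversal metadata from the new graph_traversal helper
map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree = \
    mgb.cgtools.graph_traversal(outputs)

# operators in a valid execution order; prune_reshape=True additionally drops
# the subgraph that only computes dest_shape for Reshape oprs
for opr in mgb.cgtools.get_oprs_seq(outputs, prune_reshape=True):
    print(opr.id, mgb.cgtools.get_opr_type(opr))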
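
The scheduling inside get_oprs_seq is Kahn's algorithm driven by the (indegree2opr, opr2indegree) maps that graph_traversal builds. A self-contained toy illustration of the same bucket bookkeeping (node names and the sample DAG are made up for illustration, not from the patch):

import collections

# toy DAG: a -> b, a -> c, b -> d, c -> d
receivers = {"a": ["b", "c"], "b": ["d"], "c": ["d"], "d": []}
indegree = {"a": 0, "b": 1, "c": 1, "d": 2}

# bucket nodes by in-degree, as graph_traversal does with indegree2opr
indegree2node = collections.defaultdict(set)
for node, deg in indegree.items():
    indegree2node[deg].add(node)

order = []
while indegree2node[0]:
    node = indegree2node[0].pop()
    order.append(node)
    for nxt in receivers[node]:
        # move nxt one bucket down, exactly as topological_sort does
        indegree2node[indegree[nxt]].remove(nxt)
        indegree[nxt] -= 1
        indegree2node[indegree[nxt]].add(nxt)

print(order)  # e.g. ['a', 'b', 'c', 'd'] or ['a', 'c', 'b', 'd']

Keeping one bucket per in-degree makes popping the zero-in-degree frontier an O(1) operation, which is why graph_traversal precomputes indegree2opr alongside the plain opr2indegree dict.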