diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 54b913d920b10c0eba2a6040f4dbeb082ea61c0c..28bfc7e14db8c88eafe71282cf6bfbf202f2d9a1 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -893,6 +893,10 @@ class trace: if isinstance(file, str): permission = "wb" if append == False else "ab" file = open(file, permission) + + if keep_opr_priority: + graph._set_priority_to_id(dest_vars) + dump_content, dump_info = G.dump_graph( dest_vars, keep_var_name=keep_var_name, diff --git a/imperative/python/megengine/utils/comp_graph_tools.py b/imperative/python/megengine/utils/comp_graph_tools.py index 379b116745bd7143eedc5cdf1bd85d46ae111262..840f1e9eedec0b1ce4d501f6b6284f4e4623fbb2 100644 --- a/imperative/python/megengine/utils/comp_graph_tools.py +++ b/imperative/python/megengine/utils/comp_graph_tools.py @@ -6,6 +6,7 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import collections +import heapq from collections import OrderedDict from typing import Dict, List, Tuple, Union @@ -88,6 +89,41 @@ def get_opr_type(opr: _OpNode) -> str: return opr.type +class _OprStableOrderHeapq: + """heap implementation for operator comparison in stable order""" + + _list = None + _extra_priority = None + _used_id_name_pairs = None + + def __init__(self, extra_priority): + assert isinstance(extra_priority, collections.Callable) + self._list = [] + self._extra_priority = extra_priority + self._used_id_name_pairs = {} + + def pop_min(self): + return heapq.heappop(self._list)[-1] + + def add(self, opr): + # named as add to mimic set() interface + + id_ = opr.id + name = opr.name + + other = self._used_id_name_pairs.setdefault((id_, name), opr) + if other is not opr: + raise RuntimeError( + "duplicated (id, name) pair: opr0={} opr1={}".format(other, opr) + ) + + item = self._extra_priority(opr) + (id_, name, opr) + heapq.heappush(self._list, item) + + def __bool__(self): + return bool(self._list) + + def graph_traversal(outputs: _VarNode): """ Helper function to traverse the computing graph and return enough useful information. @@ -110,12 +146,13 @@ def graph_traversal(outputs: _VarNode): var2oprs = collections.defaultdict(list) opr2receivers = collections.defaultdict(list) - - queue = list(set(map(lambda x: x.owner, outputs))) + queue = [] + [queue.append(o) for o in [x.owner for x in outputs] if o not in queue] visited = set(map(lambda x: x.id, queue)) # iterate through whole comp_graph, fill in meta information indegree2opr = collections.defaultdict(set) + indegree2opr[0] = _OprStableOrderHeapq(lambda op: (op.priority,)) opr2indegree = {} idx = 0 @@ -138,8 +175,8 @@ def graph_traversal(outputs: _VarNode): indegree += 1 opr2receivers[pre_opr.id].append(cur_opr.id) - - indegree2opr[indegree].add(cur_opr.id) + opr = cur_opr if indegree == 0 else cur_opr.id + indegree2opr[indegree].add(opr) opr2indegree[cur_opr.id] = indegree return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree @@ -162,8 +199,8 @@ def get_oprs_seq( oprs_seq = [] nr_remain = len(map_oprs) while indegree2opr[0]: - opr_id = indegree2opr[0].pop() - opr = map_oprs[opr_id] + opr = indegree2opr[0].pop_min() + opr_id = opr.id nr_remain -= 1 if opr.type != "ImmutableTensor" or not prune_immtensor: oprs_seq.append(opr) @@ -173,7 +210,10 @@ def get_oprs_seq( indegree2opr[indegree].remove(post_id) indegree -= 1 - indegree2opr[indegree].add(post_id) + if indegree == 0: + indegree2opr[indegree].add(map_oprs[post_id]) + else: + indegree2opr[indegree].add(post_id) opr2indegree[post_id] = indegree assert nr_remain == 0, "there are {} remaining nodes; cyclic graph?".format( @@ -213,10 +253,34 @@ def get_oprs_seq( # filter out all marked oprs return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq)) + # adjust the order of oprs, let param/data privoder oprs close to the oprs which use them as inputs. + def reorder_oprs_seq(oprs): + rst = [] + param_or_data_provider_oprs = [] + other_oprs = [] + + for o in oprs: + if o.type in ["ImmutableTensor", "Host2DeviceCopy"]: + param_or_data_provider_oprs.append(o) + else: + other_oprs.append(o) + + for o in other_oprs: + for inp in o.inputs: + if inp.owner.type in ["ImmutableTensor", "Host2DeviceCopy"]: + if inp.owner in param_or_data_provider_oprs: + rst.append(inp.owner) + param_or_data_provider_oprs.remove(inp.owner) + rst.append(o) + rst = rst + param_or_data_provider_oprs + assert len(rst) == len(oprs) + return rst + map_oprs, _, var2oprs, opr2receivers, indegree2opr, opr2indegree = graph_traversal( outputs ) oprs_seq = topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree) + oprs_seq = reorder_oprs_seq(oprs_seq) if prune_reshape is True: oprs_seq = prune_reshape_oprs(outputs, oprs_seq, var2oprs.copy()) return oprs_seq diff --git a/imperative/python/megengine/utils/network.py b/imperative/python/megengine/utils/network.py index 4bb167e652884f1c9b84e2f1eacfb95083c1a3fc..2cf62d3282aba3b9a9d39a0ecde54b68fef0a354 100644 --- a/imperative/python/megengine/utils/network.py +++ b/imperative/python/megengine/utils/network.py @@ -241,6 +241,7 @@ class Network: if optimize_for_inference: metadata.optimize_options = optimize_options + G.set_priority_to_id([o._node if isinstance(o, G.VarNode) else o for o in out]) dump_content, _ = G.dump_graph( out, keep_var_name=keep_var_name, @@ -353,7 +354,7 @@ class Network: ) shp[0] = batchsize i.shape = tuple(shp) - + self._compile() assert prev_batchsize is not None, "no data provider found" assert not blacklist, "unused items in blacklist: {}".format(blacklist) @@ -363,7 +364,6 @@ class Network: :param repl_dict: the map {old_var: new_var} that specifies how to replace the vars. """ if not all([var.owner for var in repl_dict.values()]): - print(repl_dict.values()) self.add_dep_oprs(*list(repl_dict.values())) for var in self.all_vars: if var in repl_dict: @@ -373,6 +373,7 @@ class Network: owner.outputs[idx] = var var.__dict__.update(repl_var.__dict__) var.var = repl_var.var + self._compile() def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]): """ @@ -384,11 +385,11 @@ class Network: assert len(opr.outputs) == len( repl_dict[opr].outputs ), "can not replace {} with {}".format(type(opr), type(repl_dict[opr])) - repl_dict[opr].outputs = opr.outputs for ind, var in enumerate(opr.outputs): var.owner = repl_dict[opr] var.__dict__.update(repl_dict[opr].outputs[ind].__dict__) var.var = repl_dict[opr].outputs[ind].var + self._compile() def get_opr_by_type(self, oprcls, unique=True): assert issubclass(oprcls, OpNode) diff --git a/imperative/python/megengine/utils/network_node.py b/imperative/python/megengine/utils/network_node.py index 314c73e8feb6be7e052eca3bbd1782830581b4bb..7357d1749e29dc2b92ec7e12b8d53f51744c072f 100644 --- a/imperative/python/megengine/utils/network_node.py +++ b/imperative/python/megengine/utils/network_node.py @@ -90,6 +90,10 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta): def dtype(self): return self.var.dtype if self.var else None + @property + def ndim(self): + return super().ndim + def __bool__(self): return False @@ -134,7 +138,18 @@ class OpNode(NetworkNode): self.outputs = [] self.params = {} self._opr = None # mgb opnode - self.id = id(self) + + @property + def id(self): + if self._opr is not None: + return self._opr.id + return id(self) + + @property + def priority(self): + if self._opr is not None: + return self._opr.priority + return 0 @classmethod def load(cls, opr): @@ -144,16 +159,21 @@ class OpNode(NetworkNode): obj._opr = opr return obj - def compile(self, graph=None): - op = self.opdef(**self.params) - args = [i.var for i in self.inputs] - outputs = rt.invoke_op(op, args) - assert len(outputs) == len(self.outputs) - self._opr = outputs[0].owner - for i in range(len(self.outputs)): - self.outputs[i].var = outputs[i] - self.outputs[i].var.name = self.outputs[i].name - assert self.outputs[i].owner is self + def compile(self): + if ( + self._opr is None + or len(self._opr.inputs) != len(self.inputs) + or any([i != j.var for i, j in zip(self._opr.inputs, self.inputs)]) + ): + op = self.opdef(**self.params) + args = [i.var for i in self.inputs] + outputs = rt.invoke_op(op, args) + assert len(outputs) == len(self.outputs) + self._opr = outputs[0].owner + for i in range(len(self.outputs)): + self.outputs[i].var = outputs[i] + self.outputs[i].var.name = self.outputs[i].name + assert self.outputs[i].owner is self def add_inp_var(self, x): self.inputs.append(x) @@ -197,11 +217,17 @@ class Host2DeviceCopy(OpNode): return self def compile(self, graph): - outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name) - self._opr = outputs.owner - if len(self.outputs) == 0: - self.outputs.append(VarNode(owner_opr=self, name=self.name)) - self.outputs[0].var = outputs + if ( + self._opr is None + or self._opr.outputs[0].comp_node != self.device + or self._opr.outputs[0].shape != self.shape + or self._opr.outputs[0].dtype != self.dtype + ): + outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name) + self._opr = outputs.owner + if len(self.outputs) == 0: + self.outputs.append(VarNode(owner_opr=self, name=self.name)) + self.outputs[0].var = outputs assert self.outputs[0].owner is self diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp index 44736685daab2bccde612839bdbfb6ee0cda07ea..ab5ba647ad0265b7458f84ffd1eebaf0a15acaa2 100644 --- a/imperative/python/src/graph_rt.cpp +++ b/imperative/python/src/graph_rt.cpp @@ -192,7 +192,14 @@ void init_graph_rt(py::module m) { }) .def("__repr__", [](cg::OperatorNodeBase* opr){ return "Opr:" + opr->name(); - }); + }) + .def_property("priority", + [](cg::OperatorNodeBase* opr) { + return opr->node_prop().attribute().priority; + }, + [](cg::OperatorNodeBase* opr, int priority) { + opr->node_prop().attribute().priority = priority; + }); py::class_(m, "AsyncExecutable") .def("execute", &cg::AsyncExecutable::execute, py::call_guard()) diff --git a/imperative/python/test/unit/utils/test_cgtools.py b/imperative/python/test/unit/utils/test_cgtools.py index c8b51daa6895231bbbfc8ce3c03eafa4489db8c9..8289da3e11464dc8307ece5ea628afc83444cac4 100644 --- a/imperative/python/test/unit/utils/test_cgtools.py +++ b/imperative/python/test/unit/utils/test_cgtools.py @@ -19,6 +19,7 @@ from megengine.core.tensor import megbrain_graph as mgb_graph from megengine.core.tensor.megbrain_graph import apply_normal_varnode from megengine.core.tensor.utils import astensor1d from megengine.jit import trace +from megengine.utils.network import Network def make_dev_tensor(value, dtype=None, device=None): @@ -143,6 +144,46 @@ def test_get_opr_seq(): assert len(seq_2) == 6 +def test_topological_sort(): + @trace(symbolic=True, capture_as_const=True) + def func(x, y): + a = x + y + a1 = F.relu(a) + a2 = F.abs(a) + a3 = F.ceil(a) * 2 + a4 = F.floor(a) + r = a1 - a2 + r1 = a3 / a4 + return r, r1 + + file = io.BytesIO() + func(megengine.tensor(1.0), megengine.tensor(2.0)) + func.dump( + file, optimize_for_inference=False, keep_opr_name=True, keep_opr_priority=True + ) + file.seek(0) + g = Network.load(file) + oprseq1 = g.all_oprs + gt = [ + "Host2DeviceCopy", + "Host2DeviceCopy", + "ADD", + "RELU", + "ABS", + "CEIL", + "ImmutableTensor", + "MUL", + "FLOOR", + "SUB", + "TRUE_DIV", + ] + for op, mode in zip(oprseq1, gt): + if op.type == "Elemwise": + assert op.params["mode"] == mode + else: + assert op.type == mode + + def test_graph_function(): class Net(M.Module): def forward(self, a, b):