提交 2d42455f 编写于 作者: M Megvii Engine Team 提交者: huangxinda

fix(mge/utils): fix toposort to get definition order

GitOrigin-RevId: 47a26dd6dda31d349f7439c64a72027e6a9a7391
上级 0c97b2a3
...@@ -893,6 +893,10 @@ class trace: ...@@ -893,6 +893,10 @@ class trace:
if isinstance(file, str): if isinstance(file, str):
permission = "wb" if append == False else "ab" permission = "wb" if append == False else "ab"
file = open(file, permission) file = open(file, permission)
if keep_opr_priority:
graph._set_priority_to_id(dest_vars)
dump_content, dump_info = G.dump_graph( dump_content, dump_info = G.dump_graph(
dest_vars, dest_vars,
keep_var_name=keep_var_name, keep_var_name=keep_var_name,
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections import collections
import heapq
from collections import OrderedDict from collections import OrderedDict
from typing import Dict, List, Tuple, Union from typing import Dict, List, Tuple, Union
...@@ -88,6 +89,41 @@ def get_opr_type(opr: _OpNode) -> str: ...@@ -88,6 +89,41 @@ def get_opr_type(opr: _OpNode) -> str:
return opr.type return opr.type
class _OprStableOrderHeapq:
"""heap implementation for operator comparison in stable order"""
_list = None
_extra_priority = None
_used_id_name_pairs = None
def __init__(self, extra_priority):
assert isinstance(extra_priority, collections.Callable)
self._list = []
self._extra_priority = extra_priority
self._used_id_name_pairs = {}
def pop_min(self):
return heapq.heappop(self._list)[-1]
def add(self, opr):
# named as add to mimic set() interface
id_ = opr.id
name = opr.name
other = self._used_id_name_pairs.setdefault((id_, name), opr)
if other is not opr:
raise RuntimeError(
"duplicated (id, name) pair: opr0={} opr1={}".format(other, opr)
)
item = self._extra_priority(opr) + (id_, name, opr)
heapq.heappush(self._list, item)
def __bool__(self):
return bool(self._list)
def graph_traversal(outputs: _VarNode): def graph_traversal(outputs: _VarNode):
""" """
Helper function to traverse the computing graph and return enough useful information. Helper function to traverse the computing graph and return enough useful information.
...@@ -110,12 +146,13 @@ def graph_traversal(outputs: _VarNode): ...@@ -110,12 +146,13 @@ def graph_traversal(outputs: _VarNode):
var2oprs = collections.defaultdict(list) var2oprs = collections.defaultdict(list)
opr2receivers = collections.defaultdict(list) opr2receivers = collections.defaultdict(list)
queue = []
queue = list(set(map(lambda x: x.owner, outputs))) [queue.append(o) for o in [x.owner for x in outputs] if o not in queue]
visited = set(map(lambda x: x.id, queue)) visited = set(map(lambda x: x.id, queue))
# iterate through whole comp_graph, fill in meta information # iterate through whole comp_graph, fill in meta information
indegree2opr = collections.defaultdict(set) indegree2opr = collections.defaultdict(set)
indegree2opr[0] = _OprStableOrderHeapq(lambda op: (op.priority,))
opr2indegree = {} opr2indegree = {}
idx = 0 idx = 0
...@@ -138,8 +175,8 @@ def graph_traversal(outputs: _VarNode): ...@@ -138,8 +175,8 @@ def graph_traversal(outputs: _VarNode):
indegree += 1 indegree += 1
opr2receivers[pre_opr.id].append(cur_opr.id) opr2receivers[pre_opr.id].append(cur_opr.id)
opr = cur_opr if indegree == 0 else cur_opr.id
indegree2opr[indegree].add(cur_opr.id) indegree2opr[indegree].add(opr)
opr2indegree[cur_opr.id] = indegree opr2indegree[cur_opr.id] = indegree
return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree
...@@ -162,8 +199,8 @@ def get_oprs_seq( ...@@ -162,8 +199,8 @@ def get_oprs_seq(
oprs_seq = [] oprs_seq = []
nr_remain = len(map_oprs) nr_remain = len(map_oprs)
while indegree2opr[0]: while indegree2opr[0]:
opr_id = indegree2opr[0].pop() opr = indegree2opr[0].pop_min()
opr = map_oprs[opr_id] opr_id = opr.id
nr_remain -= 1 nr_remain -= 1
if opr.type != "ImmutableTensor" or not prune_immtensor: if opr.type != "ImmutableTensor" or not prune_immtensor:
oprs_seq.append(opr) oprs_seq.append(opr)
...@@ -173,6 +210,9 @@ def get_oprs_seq( ...@@ -173,6 +210,9 @@ def get_oprs_seq(
indegree2opr[indegree].remove(post_id) indegree2opr[indegree].remove(post_id)
indegree -= 1 indegree -= 1
if indegree == 0:
indegree2opr[indegree].add(map_oprs[post_id])
else:
indegree2opr[indegree].add(post_id) indegree2opr[indegree].add(post_id)
opr2indegree[post_id] = indegree opr2indegree[post_id] = indegree
...@@ -213,10 +253,34 @@ def get_oprs_seq( ...@@ -213,10 +253,34 @@ def get_oprs_seq(
# filter out all marked oprs # filter out all marked oprs
return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq)) return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq))
# adjust the order of oprs, let param/data privoder oprs close to the oprs which use them as inputs.
def reorder_oprs_seq(oprs):
rst = []
param_or_data_provider_oprs = []
other_oprs = []
for o in oprs:
if o.type in ["ImmutableTensor", "Host2DeviceCopy"]:
param_or_data_provider_oprs.append(o)
else:
other_oprs.append(o)
for o in other_oprs:
for inp in o.inputs:
if inp.owner.type in ["ImmutableTensor", "Host2DeviceCopy"]:
if inp.owner in param_or_data_provider_oprs:
rst.append(inp.owner)
param_or_data_provider_oprs.remove(inp.owner)
rst.append(o)
rst = rst + param_or_data_provider_oprs
assert len(rst) == len(oprs)
return rst
map_oprs, _, var2oprs, opr2receivers, indegree2opr, opr2indegree = graph_traversal( map_oprs, _, var2oprs, opr2receivers, indegree2opr, opr2indegree = graph_traversal(
outputs outputs
) )
oprs_seq = topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree) oprs_seq = topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree)
oprs_seq = reorder_oprs_seq(oprs_seq)
if prune_reshape is True: if prune_reshape is True:
oprs_seq = prune_reshape_oprs(outputs, oprs_seq, var2oprs.copy()) oprs_seq = prune_reshape_oprs(outputs, oprs_seq, var2oprs.copy())
return oprs_seq return oprs_seq
......
...@@ -241,6 +241,7 @@ class Network: ...@@ -241,6 +241,7 @@ class Network:
if optimize_for_inference: if optimize_for_inference:
metadata.optimize_options = optimize_options metadata.optimize_options = optimize_options
G.set_priority_to_id([o._node if isinstance(o, G.VarNode) else o for o in out])
dump_content, _ = G.dump_graph( dump_content, _ = G.dump_graph(
out, out,
keep_var_name=keep_var_name, keep_var_name=keep_var_name,
...@@ -353,7 +354,7 @@ class Network: ...@@ -353,7 +354,7 @@ class Network:
) )
shp[0] = batchsize shp[0] = batchsize
i.shape = tuple(shp) i.shape = tuple(shp)
self._compile()
assert prev_batchsize is not None, "no data provider found" assert prev_batchsize is not None, "no data provider found"
assert not blacklist, "unused items in blacklist: {}".format(blacklist) assert not blacklist, "unused items in blacklist: {}".format(blacklist)
...@@ -363,7 +364,6 @@ class Network: ...@@ -363,7 +364,6 @@ class Network:
:param repl_dict: the map {old_var: new_var} that specifies how to replace the vars. :param repl_dict: the map {old_var: new_var} that specifies how to replace the vars.
""" """
if not all([var.owner for var in repl_dict.values()]): if not all([var.owner for var in repl_dict.values()]):
print(repl_dict.values())
self.add_dep_oprs(*list(repl_dict.values())) self.add_dep_oprs(*list(repl_dict.values()))
for var in self.all_vars: for var in self.all_vars:
if var in repl_dict: if var in repl_dict:
...@@ -373,6 +373,7 @@ class Network: ...@@ -373,6 +373,7 @@ class Network:
owner.outputs[idx] = var owner.outputs[idx] = var
var.__dict__.update(repl_var.__dict__) var.__dict__.update(repl_var.__dict__)
var.var = repl_var.var var.var = repl_var.var
self._compile()
def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]): def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]):
""" """
...@@ -384,11 +385,11 @@ class Network: ...@@ -384,11 +385,11 @@ class Network:
assert len(opr.outputs) == len( assert len(opr.outputs) == len(
repl_dict[opr].outputs repl_dict[opr].outputs
), "can not replace {} with {}".format(type(opr), type(repl_dict[opr])) ), "can not replace {} with {}".format(type(opr), type(repl_dict[opr]))
repl_dict[opr].outputs = opr.outputs
for ind, var in enumerate(opr.outputs): for ind, var in enumerate(opr.outputs):
var.owner = repl_dict[opr] var.owner = repl_dict[opr]
var.__dict__.update(repl_dict[opr].outputs[ind].__dict__) var.__dict__.update(repl_dict[opr].outputs[ind].__dict__)
var.var = repl_dict[opr].outputs[ind].var var.var = repl_dict[opr].outputs[ind].var
self._compile()
def get_opr_by_type(self, oprcls, unique=True): def get_opr_by_type(self, oprcls, unique=True):
assert issubclass(oprcls, OpNode) assert issubclass(oprcls, OpNode)
......
...@@ -90,6 +90,10 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta): ...@@ -90,6 +90,10 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
def dtype(self): def dtype(self):
return self.var.dtype if self.var else None return self.var.dtype if self.var else None
@property
def ndim(self):
return super().ndim
def __bool__(self): def __bool__(self):
return False return False
...@@ -134,7 +138,18 @@ class OpNode(NetworkNode): ...@@ -134,7 +138,18 @@ class OpNode(NetworkNode):
self.outputs = [] self.outputs = []
self.params = {} self.params = {}
self._opr = None # mgb opnode self._opr = None # mgb opnode
self.id = id(self)
@property
def id(self):
if self._opr is not None:
return self._opr.id
return id(self)
@property
def priority(self):
if self._opr is not None:
return self._opr.priority
return 0
@classmethod @classmethod
def load(cls, opr): def load(cls, opr):
...@@ -144,7 +159,12 @@ class OpNode(NetworkNode): ...@@ -144,7 +159,12 @@ class OpNode(NetworkNode):
obj._opr = opr obj._opr = opr
return obj return obj
def compile(self, graph=None): def compile(self):
if (
self._opr is None
or len(self._opr.inputs) != len(self.inputs)
or any([i != j.var for i, j in zip(self._opr.inputs, self.inputs)])
):
op = self.opdef(**self.params) op = self.opdef(**self.params)
args = [i.var for i in self.inputs] args = [i.var for i in self.inputs]
outputs = rt.invoke_op(op, args) outputs = rt.invoke_op(op, args)
...@@ -197,6 +217,12 @@ class Host2DeviceCopy(OpNode): ...@@ -197,6 +217,12 @@ class Host2DeviceCopy(OpNode):
return self return self
def compile(self, graph): def compile(self, graph):
if (
self._opr is None
or self._opr.outputs[0].comp_node != self.device
or self._opr.outputs[0].shape != self.shape
or self._opr.outputs[0].dtype != self.dtype
):
outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name) outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name)
self._opr = outputs.owner self._opr = outputs.owner
if len(self.outputs) == 0: if len(self.outputs) == 0:
......
...@@ -192,6 +192,13 @@ void init_graph_rt(py::module m) { ...@@ -192,6 +192,13 @@ void init_graph_rt(py::module m) {
}) })
.def("__repr__", [](cg::OperatorNodeBase* opr){ .def("__repr__", [](cg::OperatorNodeBase* opr){
return "Opr:" + opr->name(); return "Opr:" + opr->name();
})
.def_property("priority",
[](cg::OperatorNodeBase* opr) {
return opr->node_prop().attribute().priority;
},
[](cg::OperatorNodeBase* opr, int priority) {
opr->node_prop().attribute().priority = priority;
}); });
py::class_<cg::AsyncExecutable>(m, "AsyncExecutable") py::class_<cg::AsyncExecutable>(m, "AsyncExecutable")
......
...@@ -19,6 +19,7 @@ from megengine.core.tensor import megbrain_graph as mgb_graph ...@@ -19,6 +19,7 @@ from megengine.core.tensor import megbrain_graph as mgb_graph
from megengine.core.tensor.megbrain_graph import apply_normal_varnode from megengine.core.tensor.megbrain_graph import apply_normal_varnode
from megengine.core.tensor.utils import astensor1d from megengine.core.tensor.utils import astensor1d
from megengine.jit import trace from megengine.jit import trace
from megengine.utils.network import Network
def make_dev_tensor(value, dtype=None, device=None): def make_dev_tensor(value, dtype=None, device=None):
...@@ -143,6 +144,46 @@ def test_get_opr_seq(): ...@@ -143,6 +144,46 @@ def test_get_opr_seq():
assert len(seq_2) == 6 assert len(seq_2) == 6
def test_topological_sort():
@trace(symbolic=True, capture_as_const=True)
def func(x, y):
a = x + y
a1 = F.relu(a)
a2 = F.abs(a)
a3 = F.ceil(a) * 2
a4 = F.floor(a)
r = a1 - a2
r1 = a3 / a4
return r, r1
file = io.BytesIO()
func(megengine.tensor(1.0), megengine.tensor(2.0))
func.dump(
file, optimize_for_inference=False, keep_opr_name=True, keep_opr_priority=True
)
file.seek(0)
g = Network.load(file)
oprseq1 = g.all_oprs
gt = [
"Host2DeviceCopy",
"Host2DeviceCopy",
"ADD",
"RELU",
"ABS",
"CEIL",
"ImmutableTensor",
"MUL",
"FLOOR",
"SUB",
"TRUE_DIV",
]
for op, mode in zip(oprseq1, gt):
if op.type == "Elemwise":
assert op.params["mode"] == mode
else:
assert op.type == mode
def test_graph_function(): def test_graph_function():
class Net(M.Module): class Net(M.Module):
def forward(self, a, b): def forward(self, a, b):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册