提交 2d42455f 编写于 作者: M Megvii Engine Team 提交者: huangxinda

fix(mge/utils): fix toposort to get definition order

GitOrigin-RevId: 47a26dd6dda31d349f7439c64a72027e6a9a7391
上级 0c97b2a3
......@@ -893,6 +893,10 @@ class trace:
if isinstance(file, str):
permission = "wb" if append == False else "ab"
file = open(file, permission)
if keep_opr_priority:
graph._set_priority_to_id(dest_vars)
dump_content, dump_info = G.dump_graph(
dest_vars,
keep_var_name=keep_var_name,
......
......@@ -6,6 +6,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import heapq
from collections import OrderedDict
from typing import Dict, List, Tuple, Union
......@@ -88,6 +89,41 @@ def get_opr_type(opr: _OpNode) -> str:
return opr.type
class _OprStableOrderHeapq:
"""heap implementation for operator comparison in stable order"""
_list = None
_extra_priority = None
_used_id_name_pairs = None
def __init__(self, extra_priority):
assert isinstance(extra_priority, collections.Callable)
self._list = []
self._extra_priority = extra_priority
self._used_id_name_pairs = {}
def pop_min(self):
return heapq.heappop(self._list)[-1]
def add(self, opr):
# named as add to mimic set() interface
id_ = opr.id
name = opr.name
other = self._used_id_name_pairs.setdefault((id_, name), opr)
if other is not opr:
raise RuntimeError(
"duplicated (id, name) pair: opr0={} opr1={}".format(other, opr)
)
item = self._extra_priority(opr) + (id_, name, opr)
heapq.heappush(self._list, item)
def __bool__(self):
return bool(self._list)
def graph_traversal(outputs: _VarNode):
"""
Helper function to traverse the computing graph and return enough useful information.
......@@ -110,12 +146,13 @@ def graph_traversal(outputs: _VarNode):
var2oprs = collections.defaultdict(list)
opr2receivers = collections.defaultdict(list)
queue = list(set(map(lambda x: x.owner, outputs)))
queue = []
[queue.append(o) for o in [x.owner for x in outputs] if o not in queue]
visited = set(map(lambda x: x.id, queue))
# iterate through whole comp_graph, fill in meta information
indegree2opr = collections.defaultdict(set)
indegree2opr[0] = _OprStableOrderHeapq(lambda op: (op.priority,))
opr2indegree = {}
idx = 0
......@@ -138,8 +175,8 @@ def graph_traversal(outputs: _VarNode):
indegree += 1
opr2receivers[pre_opr.id].append(cur_opr.id)
indegree2opr[indegree].add(cur_opr.id)
opr = cur_opr if indegree == 0 else cur_opr.id
indegree2opr[indegree].add(opr)
opr2indegree[cur_opr.id] = indegree
return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree
......@@ -162,8 +199,8 @@ def get_oprs_seq(
oprs_seq = []
nr_remain = len(map_oprs)
while indegree2opr[0]:
opr_id = indegree2opr[0].pop()
opr = map_oprs[opr_id]
opr = indegree2opr[0].pop_min()
opr_id = opr.id
nr_remain -= 1
if opr.type != "ImmutableTensor" or not prune_immtensor:
oprs_seq.append(opr)
......@@ -173,6 +210,9 @@ def get_oprs_seq(
indegree2opr[indegree].remove(post_id)
indegree -= 1
if indegree == 0:
indegree2opr[indegree].add(map_oprs[post_id])
else:
indegree2opr[indegree].add(post_id)
opr2indegree[post_id] = indegree
......@@ -213,10 +253,34 @@ def get_oprs_seq(
# filter out all marked oprs
return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq))
# adjust the order of oprs, let param/data privoder oprs close to the oprs which use them as inputs.
def reorder_oprs_seq(oprs):
rst = []
param_or_data_provider_oprs = []
other_oprs = []
for o in oprs:
if o.type in ["ImmutableTensor", "Host2DeviceCopy"]:
param_or_data_provider_oprs.append(o)
else:
other_oprs.append(o)
for o in other_oprs:
for inp in o.inputs:
if inp.owner.type in ["ImmutableTensor", "Host2DeviceCopy"]:
if inp.owner in param_or_data_provider_oprs:
rst.append(inp.owner)
param_or_data_provider_oprs.remove(inp.owner)
rst.append(o)
rst = rst + param_or_data_provider_oprs
assert len(rst) == len(oprs)
return rst
map_oprs, _, var2oprs, opr2receivers, indegree2opr, opr2indegree = graph_traversal(
outputs
)
oprs_seq = topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree)
oprs_seq = reorder_oprs_seq(oprs_seq)
if prune_reshape is True:
oprs_seq = prune_reshape_oprs(outputs, oprs_seq, var2oprs.copy())
return oprs_seq
......
......@@ -241,6 +241,7 @@ class Network:
if optimize_for_inference:
metadata.optimize_options = optimize_options
G.set_priority_to_id([o._node if isinstance(o, G.VarNode) else o for o in out])
dump_content, _ = G.dump_graph(
out,
keep_var_name=keep_var_name,
......@@ -353,7 +354,7 @@ class Network:
)
shp[0] = batchsize
i.shape = tuple(shp)
self._compile()
assert prev_batchsize is not None, "no data provider found"
assert not blacklist, "unused items in blacklist: {}".format(blacklist)
......@@ -363,7 +364,6 @@ class Network:
:param repl_dict: the map {old_var: new_var} that specifies how to replace the vars.
"""
if not all([var.owner for var in repl_dict.values()]):
print(repl_dict.values())
self.add_dep_oprs(*list(repl_dict.values()))
for var in self.all_vars:
if var in repl_dict:
......@@ -373,6 +373,7 @@ class Network:
owner.outputs[idx] = var
var.__dict__.update(repl_var.__dict__)
var.var = repl_var.var
self._compile()
def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]):
"""
......@@ -384,11 +385,11 @@ class Network:
assert len(opr.outputs) == len(
repl_dict[opr].outputs
), "can not replace {} with {}".format(type(opr), type(repl_dict[opr]))
repl_dict[opr].outputs = opr.outputs
for ind, var in enumerate(opr.outputs):
var.owner = repl_dict[opr]
var.__dict__.update(repl_dict[opr].outputs[ind].__dict__)
var.var = repl_dict[opr].outputs[ind].var
self._compile()
def get_opr_by_type(self, oprcls, unique=True):
assert issubclass(oprcls, OpNode)
......
......@@ -90,6 +90,10 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
def dtype(self):
return self.var.dtype if self.var else None
@property
def ndim(self):
return super().ndim
def __bool__(self):
return False
......@@ -134,7 +138,18 @@ class OpNode(NetworkNode):
self.outputs = []
self.params = {}
self._opr = None # mgb opnode
self.id = id(self)
@property
def id(self):
if self._opr is not None:
return self._opr.id
return id(self)
@property
def priority(self):
if self._opr is not None:
return self._opr.priority
return 0
@classmethod
def load(cls, opr):
......@@ -144,7 +159,12 @@ class OpNode(NetworkNode):
obj._opr = opr
return obj
def compile(self, graph=None):
def compile(self):
if (
self._opr is None
or len(self._opr.inputs) != len(self.inputs)
or any([i != j.var for i, j in zip(self._opr.inputs, self.inputs)])
):
op = self.opdef(**self.params)
args = [i.var for i in self.inputs]
outputs = rt.invoke_op(op, args)
......@@ -197,6 +217,12 @@ class Host2DeviceCopy(OpNode):
return self
def compile(self, graph):
if (
self._opr is None
or self._opr.outputs[0].comp_node != self.device
or self._opr.outputs[0].shape != self.shape
or self._opr.outputs[0].dtype != self.dtype
):
outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name)
self._opr = outputs.owner
if len(self.outputs) == 0:
......
......@@ -192,6 +192,13 @@ void init_graph_rt(py::module m) {
})
.def("__repr__", [](cg::OperatorNodeBase* opr){
return "Opr:" + opr->name();
})
.def_property("priority",
[](cg::OperatorNodeBase* opr) {
return opr->node_prop().attribute().priority;
},
[](cg::OperatorNodeBase* opr, int priority) {
opr->node_prop().attribute().priority = priority;
});
py::class_<cg::AsyncExecutable>(m, "AsyncExecutable")
......
......@@ -19,6 +19,7 @@ from megengine.core.tensor import megbrain_graph as mgb_graph
from megengine.core.tensor.megbrain_graph import apply_normal_varnode
from megengine.core.tensor.utils import astensor1d
from megengine.jit import trace
from megengine.utils.network import Network
def make_dev_tensor(value, dtype=None, device=None):
......@@ -143,6 +144,46 @@ def test_get_opr_seq():
assert len(seq_2) == 6
def test_topological_sort():
@trace(symbolic=True, capture_as_const=True)
def func(x, y):
a = x + y
a1 = F.relu(a)
a2 = F.abs(a)
a3 = F.ceil(a) * 2
a4 = F.floor(a)
r = a1 - a2
r1 = a3 / a4
return r, r1
file = io.BytesIO()
func(megengine.tensor(1.0), megengine.tensor(2.0))
func.dump(
file, optimize_for_inference=False, keep_opr_name=True, keep_opr_priority=True
)
file.seek(0)
g = Network.load(file)
oprseq1 = g.all_oprs
gt = [
"Host2DeviceCopy",
"Host2DeviceCopy",
"ADD",
"RELU",
"ABS",
"CEIL",
"ImmutableTensor",
"MUL",
"FLOOR",
"SUB",
"TRUE_DIV",
]
for op, mode in zip(oprseq1, gt):
if op.type == "Elemwise":
assert op.params["mode"] == mode
else:
assert op.type == mode
def test_graph_function():
class Net(M.Module):
def forward(self, a, b):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册