提交 90107b6d 编写于 作者: M Megvii Engine Team

feat(mge/cgtools): add network vistior interface with optional pruning

GitOrigin-RevId: cfa69e3e83ecbf32d5c4827b4562dc8f65b5d674
上级 270b7488
...@@ -67,6 +67,145 @@ def get_type(var): ...@@ -67,6 +67,145 @@ def get_type(var):
return _mgb._get_owner_opr_type(var) return _mgb._get_owner_opr_type(var)
def get_opr_type(opr):
"""get the type of a opr
:type var: :class:`.Operator`
:rtype: ``str``
"""
assert isinstance(opr, _mgb.Operator)
return _mgb._get_opr_type(opr)
def graph_traversal(outputs):
"""helper function to traverse the computing graph and reeturn enough useful information
:param outputs: model outputs
:type outputs: :class:`.Symbolvar`
:return: tuple (map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree)
WHERE
map_oprs is dict from opr_id to actual opr
map_vars is dict from var_id to actual var
var2oprs is dict from var to dest oprs along with index
opr2receivers is dict from current opr to next opr
indegree2opr is dict from in_degree to opr in computing graph
opr2indegree is dict from opr in computing graph to in_degree
(indegree2opr, opr2indegree) are only used in topological sort in get_oprs_seq function
"""
# meta information for comp graph
map_oprs = collections.defaultdict(set)
map_vars = collections.defaultdict(set)
var2oprs = collections.defaultdict(list)
opr2receivers = collections.defaultdict(list)
queue = list(map(lambda x: x.owner_opr, outputs))
visited = set(map(lambda x: x.id, queue))
# iterate through whole comp_graph, fill in meta information
indegree2opr = collections.defaultdict(set)
opr2indegree = {}
idx = 0
while idx < len(queue):
cur_opr = queue[idx]
map_oprs[cur_opr.id] = cur_opr
idx += 1
indegree = 0
for var_idx, var in enumerate(cur_opr.inputs):
map_vars[var.id] = var
var2oprs[var.id].append((cur_opr.id, var_idx))
pre_opr = var.owner_opr
if pre_opr.id not in visited:
visited.add(pre_opr.id)
queue.append(pre_opr)
indegree += 1
opr2receivers[pre_opr.id].append(cur_opr.id)
indegree2opr[indegree].add(cur_opr.id)
opr2indegree[cur_opr.id] = indegree
return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree
def get_oprs_seq(outputs, prune_reshape=False):
"""get oprs in some topological order for a dumped model
:param outputs: model outputs
:param prune_reshape: whether to prune the operators useless during inference
:return: opr list with some correct execution order
"""
def topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree):
# generate an execution order with topological sort algorithm
oprs_seq = []
nr_remain = len(map_oprs)
while indegree2opr[0]:
opr_id = indegree2opr[0].pop()
opr = map_oprs[opr_id]
nr_remain -= 1
# skip const value generation operator
if get_opr_type(opr) != "ImmutableTensor":
oprs_seq.append(opr)
for post_id in opr2receivers[opr_id]:
indegree = opr2indegree[post_id]
indegree2opr[indegree].remove(post_id)
indegree -= 1
indegree2opr[indegree].add(post_id)
opr2indegree[post_id] = indegree
assert nr_remain == 0, "there are {} remaining nodes; cyclic graph?".format(
nr_remain
)
return oprs_seq
# reshape op definition: reshape(input_tensor, dest_shape) -> output_tensor
# when inferencing, shape of output_tensor is already known, so one can prune some operators related to dest_shape in the loaded graph
def prune_reshape_oprs(outputs, oprs_seq, var2oprs):
def iterative_pruning(cur_opr, post_opr, marked_opr_ids):
useless = True
for oup in cur_opr.outputs:
if "workspace" not in oup.name:
var_idx = post_opr.inputs.index(oup)
var2oprs[oup.id].remove((post_opr.id, var_idx))
useless = useless and (len(var2oprs[oup.id]) == 0)
if useless:
marked_opr_ids.append(cur_opr.id)
for inp in cur_opr.inputs:
iterative_pruning(inp.owner_opr, cur_opr, marked_opr_ids)
reshape_vars = get_dep_vars(outputs, "Reshape")
reshape_oprs = [var.owner_opr for var in reshape_vars]
marked_opr_ids = []
for reshape_opr in reshape_oprs:
iterative_pruning(
reshape_opr.inputs[1].owner_opr, reshape_opr, marked_opr_ids
)
# filter out all marked oprs
return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq))
map_oprs, _, var2oprs, opr2receivers, indegree2opr, opr2indegree = graph_traversal(
outputs
)
oprs_seq = topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree)
if prune_reshape is True:
oprs_seq = prune_reshape_oprs(outputs, oprs_seq, var2oprs.copy())
return oprs_seq
def replace_vars(dst, varmap): def replace_vars(dst, varmap):
"""replace vars in the graph """replace vars in the graph
......
...@@ -10,6 +10,10 @@ ...@@ -10,6 +10,10 @@
return var.node()->owner_opr()->dyn_typeinfo()->name; return var.node()->owner_opr()->dyn_typeinfo()->name;
} }
std::string _get_opr_type(Operator opr) {
return opr.node()->dyn_typeinfo()->name;
}
SymbolVarArray _replace_vars(const SymbolVarArray& repl_src, SymbolVarArray _replace_vars(const SymbolVarArray& repl_src,
const SymbolVarArray& repl_dst, const SymbolVarArray& repl_dst,
const SymbolVarArray& vars) { const SymbolVarArray& vars) {
......
...@@ -15,6 +15,7 @@ import pytest ...@@ -15,6 +15,7 @@ import pytest
import megengine as mge import megengine as mge
import megengine._internal as mgb import megengine._internal as mgb
import megengine.functional as F
import megengine.module as M import megengine.module as M
from megengine import functional as F from megengine import functional as F
from megengine import jit, tensor from megengine import jit, tensor
...@@ -148,6 +149,49 @@ def test_dump_volatile(): ...@@ -148,6 +149,49 @@ def test_dump_volatile():
assert mgb.cgtools.get_type(mgb.cgtools.get_inputs(out)[1]) == "SharedDeviceTensor" assert mgb.cgtools.get_type(mgb.cgtools.get_inputs(out)[1]) == "SharedDeviceTensor"
def test_graph_traversal():
net = M.Conv2d(3, 4, 3, 1, 1, groups=1, bias=False)
net.eval()
@jit.trace(symbolic=True)
def fun(data):
return net(data)
data = np.random.random([1, 3, 224, 224]).astype(np.float32)
fun.trace(data)
with mkstemp() as out:
fun.dump(out)
*_, outputs = mgb.load_comp_graph_from_file(out)
_, map_vars, var2oprs, *_ = mgb.cgtools.graph_traversal(outputs)
input_var = map_vars[1]
_, var_idx = var2oprs[input_var.id][0]
assert var_idx == 0
def test_network_visitor():
@jit.trace(symbolic=True)
def f(x):
# this line will produce shape_of, subtensor and concat op
# after pruning, they will be deleted
target_shape = (x.shape[0], -1)
return x.reshape(*target_shape)
f.trace(tensor(np.random.random([2, 3, 4, 5]).astype(np.float32)))
with mkstemp() as out:
f.dump(out)
*_, outputs = mgb.load_comp_graph_from_file(out)
all_oprs = mgb.cgtools.get_oprs_seq(outputs)
pruned_oprs = mgb.cgtools.get_oprs_seq(outputs, prune_reshape=True)
assert len(all_oprs) == len(pruned_oprs) + 3
def test_shape_tracing(): def test_shape_tracing():
for symbolic in [False, True]: for symbolic in [False, True]:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册