Commit 8118a594 authored by Megvii Engine Team

fix(mge/utils): fix get_oprs_seq of cgtools

GitOrigin-RevId: 366a56f4d5b7b607d4f14b9102c387392a3a5936
Parent ae8c3c81
......@@ -78,7 +78,6 @@ from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
from .serialization import load, save
from .tensor import Parameter, Tensor, tensor
from .version import __version__
from .utils import comp_graph_tools as cgtools
_set_fork_exec_path_for_timed_func(
sys.executable,
......
......@@ -15,6 +15,19 @@ from ..core._imperative_rt import OperatorNode, VarNode
from ..core.tensor import megbrain_graph as G
from ..core.tensor.raw_tensor import as_raw_tensor
__all__ = [
"get_dep_vars",
"get_owner_opr_inputs",
"get_owner_opr_type",
"get_opr_type",
"graph_traversal",
"get_oprs_seq",
"replace_vars",
"replace_oprs",
"set_priority_to_id",
"load_and_inference",
]
def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]:
"""
......@@ -166,7 +179,7 @@ def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNo
# reshape op definition: reshape(input_tensor, dest_shape) -> output_tensor
# at inference time the shape of output_tensor is already known, so some operators related to dest_shape can be pruned from the loaded graph
def prune_reshape_oprs(outputs, oprs_seq, var2oprs):
def iterative_pruning(cur_opr, post_opr, marked_opr_ids):
def iterative_pruning(cur_opr, post_opr, marked_opr_ids, visited):
useless = True
for oup in cur_opr.outputs:
if "workspace" not in oup.name:
......@@ -177,15 +190,20 @@ def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNo
if useless:
marked_opr_ids.append(cur_opr.id)
for inp in cur_opr.inputs:
iterative_pruning(inp.owner, cur_opr, marked_opr_ids)
for opr in set([var.owner for var in cur_opr.inputs]):
if (opr.id, cur_opr.id) not in visited:
visited.add((opr.id, cur_opr.id))
iterative_pruning(opr, cur_opr, marked_opr_ids, visited)
reshape_vars = get_dep_vars(outputs, "Reshape")
reshape_oprs = [var.owner for var in reshape_vars]
marked_opr_ids = []
visited = set()
for reshape_opr in reshape_oprs:
iterative_pruning(reshape_opr.inputs[1].owner, reshape_opr, marked_opr_ids)
iterative_pruning(
reshape_opr.inputs[1].owner, reshape_opr, marked_opr_ids, visited
)
# filter out all marked oprs
return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq))
......
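A minimal, self-contained sketch of the change above: the fix threads a `visited` set of (producer_id, consumer_id) edges through `iterative_pruning`, so each edge is traversed at most once. The `Var`/`Opr` classes and the `is_useless` predicate below are hypothetical stand-ins for MegEngine's `VarNode`/`OperatorNode` and for the output check elided in the diff; this is an illustration, not the library's actual implementation.

```python
# Sketch only: toy stand-ins for MegEngine's OperatorNode/VarNode, plus a
# hypothetical is_useless() predicate in place of the elided output check.

class Var:
    def __init__(self, name, owner):
        self.name = name      # e.g. "shape_of", "workspace", ...
        self.owner = owner    # the Opr that produced this Var

class Opr:
    _counter = 0
    def __init__(self, inputs=()):
        self.id = Opr._counter
        Opr._counter += 1
        self.inputs = list(inputs)   # list of Var
        self.outputs = []            # list of Var, filled in by the graph builder

def is_useless(cur_opr, post_opr, marked_opr_ids):
    # Hypothetical placeholder: decide whether cur_opr's non-workspace
    # outputs are only consumed by operators that are already marked
    # (the real check is elided in the diff context).
    return True

def iterative_pruning(cur_opr, post_opr, marked_opr_ids, visited):
    if is_useless(cur_opr, post_opr, marked_opr_ids):
        marked_opr_ids.append(cur_opr.id)
        # Recurse into each distinct producer, but walk every
        # (producer, consumer) edge at most once. Without the `visited`
        # set, an operator feeding several inputs of the same consumer
        # would be traversed repeatedly.
        for opr in set(var.owner for var in cur_opr.inputs):
            edge = (opr.id, cur_opr.id)
            if edge not in visited:
                visited.add(edge)
                iterative_pruning(opr, cur_opr, marked_opr_ids, visited)
```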
......@@ -13,9 +13,10 @@ import pytest
import megengine
import megengine.functional as F
import megengine.module as M
from megengine import cgtools
import megengine.utils.comp_graph_tools as cgtools
from megengine.core.tensor import megbrain_graph as mgb_graph
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.core.tensor.utils import astensor1d
from megengine.jit import trace
......@@ -98,3 +99,38 @@ def test_load_refcnt():
graph, _, (varnode,) = mgb_graph.load_graph(io.BytesIO(buf))
del graph
varnode.owner
def test_get_opr_seq():
class Net(M.Module):
def __init__(self):
super().__init__()
self.data = megengine.tensor(
np.random.random((1, 1, 4, 4)), dtype=np.float32
)
def forward(self, input):
A = input.shape[0]
shape = astensor1d((A, A), self.data, dtype="int32", device=input.device)
x = F.reshape(self.data, shape)
o = input + x
return o
net = Net()
input = megengine.tensor(np.random.random((4, 4)), dtype=np.float32)
@trace(symbolic=True, capture_as_const=True)
def func(inp, *, net=None):
return net(inp)
func(input, net=net)
file = io.BytesIO()
func.dump(file, optimize_for_inference=False)
file.seek(0)
*_, outputs = mgb_graph.load_graph(file)
seq_1 = cgtools.get_oprs_seq(outputs, True)
assert len(seq_1) == 5
seq_2 = cgtools.get_oprs_seq(outputs, False)
assert len(seq_2) == 6
......@@ -14,7 +14,8 @@ import pytest
import megengine.core.tensor.megbrain_graph as G
import megengine.functional as F
from megengine import cgtools, tensor
import megengine.utils.comp_graph_tools as cgtools
from megengine import tensor
from megengine.core._trace_option import set_symbolic_shape
from megengine.core.ops import builtin as ops
from megengine.core.ops.builtin import Elemwise
......