提交 b1c46ba4 编写于 作者: M Megvii Engine Team

feat(traced_module): add some functions of graph modification

GitOrigin-RevId: ac0603057adaedf864f2d0ceb7bfb6d3c5a50640
上级 4bb25369
......@@ -9,6 +9,7 @@
import builtins
import collections
import inspect
from typing import Callable, List
from ...core._imperative_rt import OpDef
......@@ -16,10 +17,10 @@ from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
from ...core.ops.special import Const
from ...module import Module
from ...tensor import Tensor
from ...tensor import Parameter, Tensor
from .module_tracer import active_module_tracer, module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import TreeDef
from .pytree import TreeDef, tree_flatten
class Expr:
......@@ -38,22 +39,25 @@ class Expr:
for val in vals:
node = NodeMixin.get(val, None)
if isinstance(node, (TensorNode, ModuleNode)):
if node not in self.inputs:
self.inputs.append(node)
node.users.append(self)
else:
assert node is None
assert type(val) in builtins.__dict__.values()
idx = len(self.inputs) + len(self.const_val)
self.const_val.append((idx, val))
def add_outputs(self, outputs):
def add_outputs(self, outputs, check_inplace=True):
self.outputs = []
if outputs is not None:
if not isinstance(outputs, collections.Sequence):
outputs = (outputs,)
for i in outputs:
assert isinstance(i, RawTensor)
self.outputs.append(NodeMixin.get_wrapped_type(i)(self))
node = NodeMixin.get(i, None) if check_inplace else None
self.outputs.append(
node if node else NodeMixin.get_wrapped_type(i)(self)
)
for i, node in zip(outputs, self.outputs,):
NodeMixin.wrap_safe(i, node)
......@@ -110,6 +114,7 @@ class GetAttr(Expr):
self.inputs = [
module,
]
module.users.append(self)
self.name = name
node_cls = type if type else Node
self.outputs = [
......@@ -134,10 +139,18 @@ class GetAttr(Expr):
# expr: outputs = inputs[0].__call__(*inputs[1:])
class CallMethod(Expr):
def __init__(self, module, method="__call__"):
assert isinstance(module, (TensorNode, ModuleNode))
def __init__(self, node, method="__call__"):
if isinstance(node, type):
assert issubclass(node, Tensor)
cls = Parameter if issubclass(node, Parameter) else Tensor
self.inputs = []
self.const_val = [(0, cls)]
else:
assert isinstance(node, (TensorNode, ModuleNode))
node.users.append(self)
self.inputs = [
module,
node,
]
self.const_val = []
self.method = method
......@@ -160,10 +173,13 @@ class CallMethod(Expr):
def interpret(self, *inputs):
args, kwargs = self.unflatten_args(inputs)
obj = args[0]
meth = getattr(obj, self.method)
if inspect.ismethod(meth):
args = args[1:]
outputs = getattr(obj, self.method)(*args, **kwargs)
if isinstance(outputs, RawTensor):
outputs = (outputs,)
if outputs is None:
return outputs
outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
return outputs
def __repr__(self):
......@@ -171,7 +187,7 @@ class CallMethod(Expr):
kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
return "{} = {}.{}({})".format(
", ".join(str(i) for i in self.outputs),
self.inputs[0],
self.args[0],
self.method,
", ".join([args, kwargs]),
)
......@@ -209,9 +225,8 @@ class Apply(Expr):
if node is None: # capture as constant
NodeMixin.wrap_safe(i, Constant.make(i))
apply_node = cls.make(opdef)
for i in inputs:
assert isinstance(i, RawTensor)
apply_node.inputs.append(NodeMixin.get(i))
apply_node.add_inputs(inputs)
assert not apply_node.const_val
unset_module_tracing()
outputs = apply(opdef, *inputs)
......@@ -283,7 +298,7 @@ class Constant(Expr):
return (self.value,)
def __repr__(self):
return "{} = Constant({})".format(self.outputs[0], self.value)
return "{} = Constant({})".format(self.outputs[0], type(self.value))
def __getstate__(self):
state = self.__dict__.copy()
......
......@@ -79,6 +79,8 @@ BUILTIN_ARRAY_METHOD = [
"min",
"max",
"mean",
"__getitem__",
"__setitem__",
]
......@@ -176,7 +178,8 @@ class Patcher:
self.patch_module(module)
for meth in BUILTIN_ARRAY_METHOD:
self.patch_method(ArrayMethodMixin, meth, self.wrap_fn)
self.patch_method(Tensor, "detach", self.wrap_fn)
self.patch_method(Tensor, "__new__", self.wrap_fn)
for i, j in self._builtin_functions:
if id(i) not in self.visited_frames_ids:
self.patch_function(i, j, self.wrap_fn)
......@@ -203,7 +206,13 @@ class Patcher:
import inspect
if id(module.__dict__) not in self.visited_frames_ids:
for k, v in module.__dict__.items():
keys = (
getattr(module, "__all__")
if hasattr(module, "__all__")
else module.__dict__.keys()
)
for k in keys:
v = getattr(module, k)
if inspect.isfunction(v) and not k.startswith("_"):
self.patch_function(module.__dict__, k, self.wrap_fn)
self.visited_frames_ids.add(id(module.__dict__))
......
......@@ -6,7 +6,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Any, Dict, Tuple, Type
from typing import Any, Dict, List, Tuple, Type
import numpy
......@@ -31,6 +31,7 @@ class Node:
def __init__(self, expr: "Expr", name: str = None):
self.expr = expr
self.users = [] # List[Expr]
self._id = Node.__total_id
Node.__total_id += 1
self._name = name
......@@ -59,11 +60,13 @@ class ModuleNode(Node):
module_type = Module # type: Type[Module]
attr_type_map = None # type: Dict[str, Type[Any]]
argdef_graph_map = None # type: Dict[Treedef, "InternalGraph"]
argdef_outdef_map = None # type: Dict[Treedef, Treedef]
def __init__(self, expr: "Expr", name: str = None):
super().__init__(expr, name)
self.attr_type_map = {}
self.argdef_graph_map = {}
self.argdef_outdef_map = {}
def __repr__(self):
if self._name is None:
......
......@@ -10,6 +10,8 @@
import collections
from typing import Callable, NamedTuple
import numpy as np
SUPPORTED_TYPE = {}
NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)])
......@@ -33,7 +35,7 @@ def _dict_unflatten(inps, aux_data):
register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: list(x))
register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x))
register_supported_type(dict, _dict_flatten, _dict_unflatten)
register_supported_type(
slice,
......@@ -52,6 +54,9 @@ def tree_flatten(
assert is_leaf(values)
node = LeafDef(leaf_type(values))
if is_const_leaf(values):
if isinstance(values, np.ndarray):
node.const_val = str(values)
else:
node.const_val = values
return [values,], node
......
......@@ -10,8 +10,13 @@ import collections
import copy
import functools
from inspect import getmembers, isclass, ismethod
from typing import Dict, List, Type
from typing import Callable, Dict, Iterable, List, Sequence, Type
import numpy as np
from numpy.lib.arraysetops import isin
from ... import functional as F
from ... import get_logger
from ... import module as M
from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...core._imperative_rt.core2 import (
......@@ -19,6 +24,7 @@ from ...core._imperative_rt.core2 import (
set_module_tracing,
unset_module_tracing,
)
from ...core._trace_option import set_symbolic_shape
from ...core.tensor.array_method import ArrayMethodMixin
from ...module import Module
from ...tensor import Tensor
......@@ -32,6 +38,8 @@ from .module_tracer import (
from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import tree_flatten
logger = get_logger(__name__)
def _leaf_type(node):
if isinstance(node, RawTensor):
......@@ -42,6 +50,11 @@ def _leaf_type(node):
return type(node)
def _is_leaf(node):
assert isinstance(node, RawTensor), type(node)
return isinstance(node, RawTensor)
def _is_const_leaf(node):
if isinstance(node, (RawTensor, NodeMixin, Module)):
return False
......@@ -80,7 +93,13 @@ class InternalGraph:
@property
def exprs(self):
return _expr_list(self)
return ExprFilter(_expr_iter(self))
def get_call_function(self, func: Callable = None):
return self.exprs.call_function(func)
def get_call_method(self, method: str = None):
return self.exprs.call_method(method)
def add_input(self, i):
self._inputs.append(i)
......@@ -88,14 +107,129 @@ class InternalGraph:
def add_output(self, o):
self._outputs.append(o)
def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]:
if not isinstance(nodes, Sequence):
nodes = (nodes,)
ret = list()
queue = list(nodes)
while queue:
node = queue.pop()
expr = node.expr
if expr not in ret:
ret.append(expr)
for i in expr.inputs:
if i not in queue:
queue.append(i)
return ret
def insert_call_function(self, func: Callable, nodes: Sequence[Node]):
if not isinstance(nodes, Sequence):
nodes = [nodes]
assert isinstance(func, Callable)
for i in nodes:
assert isinstance(
i, TensorNode
), "CallFunction only accept TensorNode as inputs"
expr = CallFunction(func)
expr.inputs = nodes
for i in nodes:
i.users.append(expr)
idx = max(self._exprs.index(i.expr) for i in nodes) + 1
self._exprs.insert(idx, expr)
fake_inp_val = tuple(F.zeros(shape=i.shape, dtype=i.dtype) for i in nodes)
fake_out_val = func(*fake_inp_val)
def create_node(val: Tensor):
node = TensorNode(expr)
node.shape = val.shape
node.dtype = val.dtype
return node
out_nodes = list(create_node(i) for i in fake_out_val)
expr.outputs = out_nodes
return out_nodes
def insert_call_method(self, target, method, args):
if not isinstance(args, Sequence):
args = [args]
assert isinstance(target, (TensorNode, ModuleNode))
assert isinstance(method, str)
for i in args:
assert isinstance(i, TensorNode)
expr = CallMethod(method)
expr.inputs = [target, *args]
if isinstance(target, TensorNode):
fake_target_val = F.zeros(shape=target.shape, dtype=target.dtype)
fake_inp_val = tuple(F.zeros(shape=i.shape, dtype=i.dtype) for i in args)
fake_out_val = getattr(fake_target_val, method)(fake_inp_val)
def create_node(val: Tensor):
node = TensorNode(expr)
node.shape = val.shape
node.dtype = val.dtype
return node
out_nodes = list(create_node(i) for i in fake_out_val)
expr.outputs = out_nodes
else:
raise NotImplementedError()
return out_nodes
def replace_node(self, repl_dict: Dict[Node, Node]):
while repl_dict:
node, repl_node = repl_dict.popitem()
# check graph inputs and outputs
assert node not in self.inputs, "Cannot replace inputs"
for i, n in enumerate(self.outputs):
if n is node:
self.outputs[i] = repl_node
# update users of node and repl_node
# update inputs of expr in node.users
dep_exprs = self.get_dep_exprs(repl_node)
i = 0
while i < len(node.users):
n = node.users[i]
if n in dep_exprs:
logger.info("Find a loop: ignore this replacement once")
logger.info("node: %s" % node.__repr__())
logger.info("repl_node: %s" % repl_node.__repr__())
i += 1
continue
repl_node.users.append(n)
node.users.pop(i)
idx = n.inputs.index(node)
n.inputs[idx] = repl_node
def compile(self):
"""
Delete unused expr.
"""
dep_exprs = self.get_dep_exprs(self.outputs)
i = 0
while i < len(self._exprs):
expr = self._exprs[i]
if expr in dep_exprs:
i += 1
continue
for n in expr.inputs:
n.users.remove(expr)
self._exprs.remove(expr)
def interpret(self, *inputs):
# TODO: support kwargs ?
# TODO: skip expressions which are independent and have no side effect
node2value = {}
for n, v in zip(self._inputs, inputs):
node2value[n] = v
for expr in self._exprs:
values = expr.interpret(*list(node2value[i] for i in expr.inputs))
if values is not None:
for n, v in zip(expr.outputs, values):
node2value[n] = v
return list(node2value[i] for i in self._outputs)
......@@ -109,7 +243,8 @@ class InternalGraph:
def _get_meth_name(obj, func):
for cls in type(obj).mro():
tp = obj if isinstance(obj, type) else type(obj)
for cls in tp.mro():
for k, v in cls.__dict__.items():
if v == func:
return k
......@@ -131,14 +266,30 @@ def _wrapped_function(orig_func):
meth_name = _get_meth_name(args[0], wrapped_fn)
if meth_name:
self = inputs[0]
if meth_name == "__new__":
if all([not isinstance(i, RawTensor) for i in inputs]):
# only trace Tensor.__new__() when there are tensors in args
set_module_tracing()
return orig_func(*args, **kwargs)
if isinstance(args[1], RawTensor):
node = NodeMixin.get(inputs[1])
inputs[1] = copy.copy(inputs[1])
# copy inputs[1] to avoid tensor and Tensor(tensor) share same m_tensor, which will cause they have same _NodeMixin__node in tracing.
NodeMixin.wrap_safe(inputs[1], node)
args, kwargs = tree_def.unflatten(inputs)
call_node = CallMethod.make(self, meth_name)
else:
call_node = CallMethod.make(NodeMixin.get(self), meth_name)
call_node.add_inputs(inputs[1:])
else:
call_node = CallFunction.make(orig_func)
call_node.add_inputs(inputs)
call_node.arg_def = tree_def
outputs = orig_func(*args, **kwargs)
if meth_name == "__new__":
call_node.add_outputs(outputs, False)
else:
call_node.add_outputs(outputs)
set_module_tracing()
return outputs
......@@ -197,13 +348,14 @@ class TracedModuleBuilder(NodeMixin):
mark_constant(i)
callnode = CallMethod.make(NodeMixin.get(self))
callnode.add_inputs(inputs)
callnode.add_inputs(inputs[1:])
callnode.arg_def = tree_def
if self._is_builtin:
unset_module_tracing()
outputs = self._mod(*args, **kwargs)
rst = self._mod(*args, **kwargs)
outputs, out_def = tree_flatten(rst, leaf_type=_leaf_type, is_leaf=_is_leaf)
set_module_tracing()
if self._is_builtin:
self._body = None
......@@ -215,14 +367,13 @@ class TracedModuleBuilder(NodeMixin):
NodeMixin.wrap_safe(
self, Input.make("self", NodeMixin.get_wrapped_type(self))
)
origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]
# prepare args and kwargs for inner graph
def wrap(x):
wrapped = copy.copy(x) # FIXME
NodeMixin.wrap(
wrapped,
lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)),
x, lambda: Input.make(type=NodeMixin.get_wrapped_type(x)),
)
return wrapped
return x
args = [self]
for i in inputs[1:]:
......@@ -231,21 +382,25 @@ class TracedModuleBuilder(NodeMixin):
active_module_tracer().patcher.auto_patch(
getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
)
outputs = type(self._mod).forward(*args, **kwargs)
rst = type(self._mod).forward(*args, **kwargs)
outputs, out_def = tree_flatten(rst, leaf_type=_leaf_type, is_leaf=_is_leaf)
for i in (
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
):
active_module_tracer().current_scope().add_output(NodeMixin.get(i))
NodeMixin.wrap_safe(self, orig_self)
for arg, node in zip(inputs[1:], origin_inp_node):
if node:
NodeMixin.wrap_safe(arg, node)
active_module_tracer().pop_scope()
# rebind output to outer graph
callnode.add_outputs(outputs)
self_node = NodeMixin.get(self)
self_node.argdef_graph_map[callnode.arg_def] = self._body
return outputs
self_node.argdef_outdef_map[callnode.arg_def] = out_def
return rst
def __getattr__(self, name):
if name not in self._mod.__dict__:
......@@ -268,7 +423,8 @@ class TracedModuleBuilder(NodeMixin):
return super().__getattribute__(name)
else:
wrapped = super().__getattribute__(name)
if name in self._mod.__dict__ and not NodeMixin.get(wrapped, None):
if name in self._mod.__dict__:
if not NodeMixin.get(wrapped, None):
assert not self._is_builtin
NodeMixin.wrap(
wrapped,
......@@ -278,10 +434,18 @@ class TracedModuleBuilder(NodeMixin):
type=NodeMixin.get_wrapped_type(wrapped),
),
)
else:
node = NodeMixin.get(wrapped)
expr = GetAttr.make(
NodeMixin.get(self),
name,
type=NodeMixin.get_wrapped_type(wrapped),
).expr
expr.outputs[0] = node
return wrapped
class _expr_list:
class _expr_iter:
def __init__(self, graph: InternalGraph):
self.graph = graph
......@@ -295,6 +459,59 @@ class _expr_list:
yield expr
class ExprFilter:
def __init__(self, expr_iter: Iterable):
self._iter = expr_iter
def __iter__(self):
return iter(self._iter)
def call_function(self, func):
return ExprFilterCallFunction(self, func)
def call_method(self, method):
return ExprFilterCallMethod(self, method)
def as_list(self):
return list(self)
def as_dict(self):
raise NotImplementedError("need key")
def as_unique(self):
(expr,) = self
return expr
def as_count(self):
return sum(1 for _ in self)
class ExprFilterCallFunction(ExprFilter):
def __init__(self, expr_iter, func: Callable = None):
super().__init__(expr_iter)
self.func = func
def __iter__(self):
for i in self._iter:
if not isinstance(i, CallFunction):
continue
if self.func is None or i.func == self.func:
yield i
class ExprFilterCallMethod(ExprFilter):
def __init__(self, expr_iter, method: str = None):
super().__init__(expr_iter)
self.method = method
def __iter__(self):
for i in self._iter:
if not isinstance(i, CallMethod):
continue
if self.method is None or i.method == self.method:
yield i
class TracedModule(Module):
"""
`TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node). `TracedModule` can not be called directly. It can be
......@@ -312,10 +529,12 @@ class TracedModule(Module):
((self, *args), kwargs), _leaf_type, is_const_leaf=_is_const_leaf
)
assert treedef in self.m_node.argdef_graph_map
inputs = [i for i in inputs if isinstance(i, (Module, RawTensor))]
inputs = filter(
lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs
) # allow TracedModuleBuilder for retrace.
outputs = self.m_node.argdef_graph_map[treedef].interpret(*inputs)
if len(outputs) == 1:
return outputs[0]
out_def = self.m_node.argdef_outdef_map[treedef]
outputs = out_def.unflatten(outputs)
return outputs
@property
......@@ -339,9 +558,8 @@ class TracedModule(Module):
if graph is None:
assert not isinstance(module, TracedModule)
const = Constant(module)
modulenode = const.outputs[0]
modulenode.module_type = type(module)
call.inputs[0] = modulenode
const.outputs[0] = call.inputs[0]
const.outputs[0].expr = const
return [const, call]
exprs = []
for expr in graph._exprs:
......@@ -350,30 +568,41 @@ class TracedModule(Module):
if call and inp in graph._inputs:
inp_idx = graph._inputs.index(inp)
expr.inputs[idx] = call.inputs[inp_idx]
call.inputs[inp_idx].users.append(expr)
# replace outputs for submodule's expr
for idx, outp in enumerate(expr.outputs):
if call and outp in graph._outputs:
oup_idx = graph._outputs.index(outp)
expr.outputs[idx] = call.outputs[oup_idx]
call.outputs[oup_idx].expr = expr
if isinstance(expr, GetAttr):
# replace GetAttr with Constant
if isinstance(expr.outputs[0], TensorNode):
const = Constant(getattr(module, expr.name))
const.outputs = expr.outputs
const.outputs[0].expr = const
exprs.append(const)
elif isinstance(expr, CallMethod):
obj_node = expr.inputs[0]
if isinstance(obj_node, ModuleNode):
assert isinstance(expr.inputs[0].expr, GetAttr)
pre_expr = expr.inputs[0].expr
if isinstance(pre_expr, GetAttr):
(obj,) = expr.inputs[0].expr.interpret(module)
exprs.extend(_flatten_subgraph(expr.graph, obj, expr))
else:
# module has been replaced.
assert isinstance(pre_expr, Constant)
else:
exprs.append(expr)
else:
exprs.append(expr)
if call is not None:
for i in call.inputs:
i.users.remove(call)
return exprs
new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module)
......@@ -422,16 +651,19 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
"""
assert active_module_tracer() is None
try:
use_sym_shape = set_symbolic_shape(True)
set_module_tracing()
set_active_module_tracer(module_tracer(_wrapped_function))
with active_module_tracer().patcher:
global_scope = InternalGraph()
active_module_tracer().push_scope(global_scope)
builder = TracedModuleBuilder(mod, True)
NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
inputs, _ = tree_flatten((args, kwargs))
inputs, _ = tree_flatten((args, kwargs), is_const_leaf=_is_const_leaf)
for _, i in enumerate(inputs):
if isinstance(i, RawTensor):
NodeMixin.wrap_safe(
i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i))
)
......@@ -439,5 +671,6 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
active_module_tracer().pop_scope()
return builder.build()
finally:
set_symbolic_shape(use_sym_shape)
set_active_module_tracer(None)
unset_module_tracing()
import io
import pickle
import numpy as np
import megengine.functional as F
import megengine.module as M
import megengine.utils.comp_graph_tools as cgtools
from megengine.core._trace_option import set_symbolic_shape
from megengine.experimental.traced_module import trace_module
from megengine.jit import trace
set_symbolic_shape(True)
class Main(M.Module):
def forward(self, x):
return x
class PreProcess(M.Module):
def __init__(self):
super().__init__()
self.I = F.ones((1,))
self.M = F.zeros((1,))
def forward(self, data, idx, roi):
N, H, W, C = data.shape
xmax = roi[:, 1, 0]
xmin = roi[:, 0, 0]
ymax = roi[:, 1, 1]
ymin = roi[:, 0, 1]
scale = F.maximum((xmax - xmin) / W, (ymax - ymin) / H)
I = F.broadcast_to(self.I, (N,))
M = F.broadcast_to(self.M, (N, 3, 3))
M[:, 0, 0] = scale
M[:, 0, 2] = xmin
M[:, 1, 1] = scale
M[:, 1, 2] = ymin
M[:, 2, 2] = I
resized = (
F.warp_perspective(
data, M, (H, W), mat_idx=idx, border_mode="CONSTANT", format="NHWC"
)
.transpose(0, 3, 1, 2)
.astype(np.float32)
)
return resized
class Net(M.Module):
def __init__(self, traced_module):
super().__init__()
self.pre_process = PreProcess()
self.traced_module = traced_module
def forward(self, data, idx, roi):
x = self.pre_process(data, idx, roi)
x = self.traced_module(x)
return x
def test_preprocess():
module = Main()
data = F.ones((1, 14, 8, 8), dtype=np.uint8)
traced_module = trace_module(module, data)
obj = pickle.dumps(traced_module)
traced_module = pickle.loads(obj)
module = Net(traced_module)
module.eval()
idx = F.zeros((1,), dtype=np.int32)
roi = F.ones((1, 2, 2), dtype=np.float32)
y = module(data, idx, roi)
traced_module = trace_module(module, data, idx, roi)
np.testing.assert_array_equal(traced_module(data, idx, roi), y)
func = trace(traced_module, capture_as_const=True)
np.testing.assert_array_equal(func(data, idx, roi), y)
model = io.BytesIO()
func.dump(model, arg_names=("data", "idx", "roi"))
model.seek(0)
infer_cg = cgtools.GraphInference(model)
np.testing.assert_allclose(
list(
infer_cg.run(
inp_dict={"data": data.numpy(), "idx": idx.numpy(), "roi": roi.numpy()}
).values()
)[0],
y,
atol=1e-6,
)
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import megengine.functional as F
import megengine.module as M
from megengine.experimental.traced_module import trace_module
from megengine.experimental.traced_module.expr import CallFunction, GetAttr
class MyBlock(M.Module):
def __init__(self, in_channels=3, channels=3):
super(MyBlock, self).__init__()
self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=1, bias=False)
self.bn1 = M.BatchNorm2d(channels)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x) + 1
return x
class MyModule(M.Module):
def __init__(self):
super(MyModule, self).__init__()
self.block0 = MyBlock()
self.block1 = MyBlock()
def forward(self, x):
x = self.block0(x)
x = self.block1(x)
return x
def _init_cls(cls):
module = cls()
x = F.ones((1, 3, 3, 3))
y = module(x)
traced_module = trace_module(module, x)
return traced_module, x, y
def _init_block():
return _init_cls(MyBlock)
def _init_module():
return _init_cls(MyModule)
def test_search():
traced_module, *_ = _init_block()
graph = traced_module.graph
relu_expr = graph.get_call_function(F.relu).as_unique()
assert isinstance(relu_expr, CallFunction) and relu_expr.func == F.relu
def test_insert():
traced_module, x, expect = _init_block()
graph = traced_module.graph
relu_node = graph.get_call_function(F.relu).as_unique().outputs
neg_node = graph.insert_call_function(F.neg, relu_node)
graph.replace_node({relu_node[0]: neg_node[0]})
graph.compile()
np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6)
def test_delete():
traced_module, x, expect = _init_block()
graph = traced_module.graph
relu_expr = graph.get_call_function(F.relu).as_unique()
node = relu_expr.outputs
repl_node = relu_expr.inputs
graph.replace_node({node[0]: repl_node[0]})
graph.compile()
np.testing.assert_allclose(expect - 1, F.relu(traced_module(x) - 1), atol=1e-6)
def test_flatten():
traced_module, x, expect = _init_module()
traced_module = traced_module.flatten()
traced_module.graph.compile()
assert all(not isinstance(i, GetAttr) for i in traced_module.graph._exprs)
assert len(traced_module.graph._exprs) == 12
def test_extra_block():
class PostProcess(M.Module):
def forward(self, x):
return x * 2
class Net(M.Module):
def __init__(self, traced_module):
super().__init__()
self.post_process = PostProcess()
self.traced_module = traced_module
def forward(self, x):
x = self.traced_module(x)
x = self.post_process(x)
return x
traced_module, x, expect = _init_block()
module = Net(traced_module)
np.testing.assert_allclose(2 * expect, module(x), atol=1e-6)
traced_module = trace_module(module, x)
np.testing.assert_allclose(2 * expect, traced_module(x), atol=1e-6)
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import pickle
import numpy as np
import megengine.functional as F
import megengine.module as M
from megengine import Tensor
from megengine.experimental.traced_module import trace_module
from megengine.module import Module
class MyBlock(Module):
def __init__(self, in_channels, channels):
super(MyBlock, self).__init__()
self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=1, bias=False)
self.bn1 = M.BatchNorm2d(channels)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x) + 1
return x
class MyModule(Module):
def __init__(self):
super(MyModule, self).__init__()
self.block0 = MyBlock(8, 4)
self.block1 = MyBlock(4, 2)
def forward(self, x):
x = self.block0(x)
x = self.block1(x)
return x
def test_dump_and_load():
module = MyModule()
x = Tensor(np.ones((1, 8, 14, 14)))
expect = module(x)
traced_module = trace_module(module, x)
np.testing.assert_array_equal(expect, traced_module(x))
obj = pickle.dumps(traced_module)
pickle.loads(obj)
np.testing.assert_array_equal(expect, traced_module(x))
import numpy as np
from megengine import Tensor
from megengine.experimental.traced_module import trace_module
from megengine.module import Module as M
class MyModule1(M):
def forward(self, x):
y = Tensor(x)
y += 1
x = x + 2
return x, y
class MyModule2(M):
def forward(self, x):
y = Tensor([1, x, 1])
y += 1
x = x + 2
return x, y
def test_trace_module():
x = Tensor(1)
m1 = MyModule1()
tm1 = trace_module(m1, x)
m2 = MyModule2()
tm2 = trace_module(m2, x)
inp = Tensor(2)
gt = m1(inp)
output = tm1(inp)
for a, b in zip(output, gt):
np.testing.assert_equal(a.numpy(), b.numpy())
gt1 = m2(inp)
output1 = tm2(inp)
for a, b in zip(output1, gt1):
np.testing.assert_equal(a.numpy(), b.numpy())
import io
import pickle
import numpy as np
import megengine as mge
import megengine.functional as F
import megengine.module as M
import megengine.utils.comp_graph_tools as cgtools
from megengine.core._trace_option import set_symbolic_shape
from megengine.experimental.traced_module import trace_module
from megengine.jit import trace
set_symbolic_shape(True)
class Main(M.Module):
def forward(self, x):
return x["data"]
class PreProcess(M.Module):
def __init__(self):
super().__init__()
self.A = F.zeros((1,))
self.I = F.ones((1,))
self.bb_out = mge.tensor(
np.array([[[0, 0], [160, 0], [160, 48], [0, 48]]], dtype="float32")
)
def forward(self, data, quad):
"""
data: (1, 3, 48, 160)
quad: (1, 4, 2)
"""
N = quad.shape[0]
dst = F.repeat(self.bb_out, N, axis=0).reshape(-1, 4, 2)
I = F.broadcast_to(self.I, quad.shape)
A = F.broadcast_to(self.A, (N, 8, 8))
A[:, 0:4, 0:2] = quad
A[:, 4:8, 5:6] = I[:, :, 0:1]
A[:, 0:4, 6:8] = -quad * dst[:, :, 0:1]
A[:, 4:8, 3:5] = quad
A[:, 0:4, 2:3] = I[:, :, 0:1]
A[:, 4:8, 6:8] = -quad * dst[:, :, 1:2]
B = dst.transpose(0, 2, 1).reshape(-1, 8, 1)
M = F.concat([F.matmul(F.matinv(A), B)[:, :, 0], I[:, 0:1, 0]], axis=1).reshape(
-1, 3, 3
)
new_data = F.warp_perspective(data, M, (48, 160)) # (N, 3, 48, 160)
return {"data": new_data}
class Net(M.Module):
def __init__(self, traced_module):
super().__init__()
self.pre_process = PreProcess()
self.traced_module = traced_module
def forward(self, data, quad):
x = self.pre_process(data, quad)
x = self.traced_module(x)
return x
def test_preprocess():
batch_size = 2
module = Main()
data = mge.tensor(
np.random.randint(0, 256, size=(batch_size, 3, 48, 160)), dtype=np.float32
)
traced_module = trace_module(module, {"data": data})
obj = pickle.dumps(traced_module)
traced_module = pickle.loads(obj)
module = Net(traced_module)
module.eval()
quad = mge.tensor(np.random.normal(size=(batch_size, 4, 2)), dtype=np.float32)
expect = module(data, quad)
traced_module = trace_module(module, data, quad)
actual = traced_module(data, quad)
for i, j in zip(expect, actual):
np.testing.assert_array_equal(i, j)
func = trace(traced_module, capture_as_const=True)
actual = func(data, quad)
for i, j in zip(expect, actual):
np.testing.assert_array_equal(i, j)
model = io.BytesIO()
func.dump(model, arg_names=("data", "quad"))
model.seek(0)
infer_cg = cgtools.GraphInference(model)
actual = list(
infer_cg.run(inp_dict={"data": data.numpy(), "quad": quad.numpy()}).values()
)[0]
np.testing.assert_allclose(expect, actual)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册