diff --git a/imperative/python/megengine/experimental/traced_module/expr.py b/imperative/python/megengine/experimental/traced_module/expr.py
index 361e19ec8bf559542e910016e48c1999d368d650..1cb592d456c5180ba955839820355c571e0caa50 100644
--- a/imperative/python/megengine/experimental/traced_module/expr.py
+++ b/imperative/python/megengine/experimental/traced_module/expr.py
@@ -9,6 +9,7 @@
 import builtins
 import collections
+import inspect
 from typing import Callable, List
 
 from ...core._imperative_rt import OpDef
@@ -16,10 +17,10 @@
 from ...core._imperative_rt.core2 import Tensor as RawTensor
 from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
 from ...core.ops.special import Const
 from ...module import Module
-from ...tensor import Tensor
+from ...tensor import Parameter, Tensor
 from .module_tracer import active_module_tracer, module_tracer
 from .node import ModuleNode, Node, NodeMixin, TensorNode
-from .pytree import TreeDef
+from .pytree import TreeDef, tree_flatten
 
 
 class Expr:
@@ -38,25 +39,28 @@ class Expr:
         for val in vals:
             node = NodeMixin.get(val, None)
             if isinstance(node, (TensorNode, ModuleNode)):
-                if node not in self.inputs:
-                    self.inputs.append(node)
+                self.inputs.append(node)
+                node.users.append(self)
             else:
                 assert node is None
-                assert type(val) in builtins.__dict__.values()
                 idx = len(self.inputs) + len(self.const_val)
                 self.const_val.append((idx, val))
 
-    def add_outputs(self, outputs):
+    def add_outputs(self, outputs, check_inplace=True):
         self.outputs = []
-        if not isinstance(outputs, collections.Sequence):
-            outputs = (outputs,)
+        if outputs is not None:
+            if not isinstance(outputs, collections.Sequence):
+                outputs = (outputs,)
 
-        for i in outputs:
-            assert isinstance(i, RawTensor)
-            self.outputs.append(NodeMixin.get_wrapped_type(i)(self))
+            for i in outputs:
+                assert isinstance(i, RawTensor)
+                node = NodeMixin.get(i, None) if check_inplace else None
+                self.outputs.append(
+                    node if node else NodeMixin.get_wrapped_type(i)(self)
+                )
 
-        for i, node in zip(outputs, self.outputs,):
-            NodeMixin.wrap_safe(i, node)
+            for i, node in zip(outputs, self.outputs,):
+                NodeMixin.wrap_safe(i, node)
 
     def unflatten_args(self, inputs):
         if self.arg_def is not None:
@@ -110,6 +114,7 @@ class GetAttr(Expr):
         self.inputs = [
             module,
         ]
+        module.users.append(self)
         self.name = name
         node_cls = type if type else Node
         self.outputs = [
@@ -134,12 +139,20 @@
 
 # expr: outputs = inputs[0].__call__(*inputs[1:])
 class CallMethod(Expr):
-    def __init__(self, module, method="__call__"):
-        assert isinstance(module, (TensorNode, ModuleNode))
-        self.inputs = [
-            module,
-        ]
-        self.const_val = []
+    def __init__(self, node, method="__call__"):
+        if isinstance(node, type):
+            assert issubclass(node, Tensor)
+            cls = Parameter if issubclass(node, Parameter) else Tensor
+
+            self.inputs = []
+            self.const_val = [(0, cls)]
+        else:
+            assert isinstance(node, (TensorNode, ModuleNode))
+            node.users.append(self)
+            self.inputs = [
+                node,
+            ]
+            self.const_val = []
         self.method = method
 
     @classmethod
@@ -160,10 +173,13 @@ class CallMethod(Expr):
     def interpret(self, *inputs):
         args, kwargs = self.unflatten_args(inputs)
         obj = args[0]
-        args = args[1:]
+        meth = getattr(obj, self.method)
+        if inspect.ismethod(meth):
+            args = args[1:]
         outputs = getattr(obj, self.method)(*args, **kwargs)
-        if isinstance(outputs, RawTensor):
-            outputs = (outputs,)
+        if outputs is None:
+            return outputs
+        outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
         return outputs
 
     def __repr__(self):
@@ -171,7 +187,7 @@ class CallMethod(Expr):
         kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
         return "{} = {}.{}({})".format(
             ", ".join(str(i) for i in self.outputs),
-            self.inputs[0],
+            self.args[0],
             self.method,
             ", ".join([args, kwargs]),
         )
@@ -209,9 +225,8 @@ class Apply(Expr):
             if node is None:  # capture as constant
                 NodeMixin.wrap_safe(i, Constant.make(i))
         apply_node = cls.make(opdef)
-        for i in inputs:
-            assert isinstance(i, RawTensor)
-            apply_node.inputs.append(NodeMixin.get(i))
+        apply_node.add_inputs(inputs)
+        assert not apply_node.const_val
 
         unset_module_tracing()
         outputs = apply(opdef, *inputs)
@@ -283,7 +298,7 @@ class Constant(Expr):
         return (self.value,)
 
     def __repr__(self):
-        return "{} = Constant({})".format(self.outputs[0], self.value)
+        return "{} = Constant({})".format(self.outputs[0], type(self.value))
 
     def __getstate__(self):
         state = self.__dict__.copy()
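With `CallMethod.interpret` now flattening results through `tree_flatten`, an expr can produce any pytree of tensors as its outputs instead of a single `RawTensor`. A minimal sketch of what this enables (the `MultiOut` module and its values are illustrative, not part of this patch):

    import megengine.module as M
    from megengine import Tensor
    from megengine.experimental.traced_module import trace_module

    class MultiOut(M.Module):
        def forward(self, x):
            # a tuple result: each tensor leaf becomes one output node of the expr
            return x + 1, x * 2

    tm = trace_module(MultiOut(), Tensor([1.0]))
    a, b = tm(Tensor([2.0]))  # the tuple structure survives tracing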
diff --git a/imperative/python/megengine/experimental/traced_module/module_tracer.py b/imperative/python/megengine/experimental/traced_module/module_tracer.py
index 57d69dbdbbf70feb5b46ab4d6052499fcbfa7639..6bdd65f951e9576e75722dacdd613dac63aff009 100644
--- a/imperative/python/megengine/experimental/traced_module/module_tracer.py
+++ b/imperative/python/megengine/experimental/traced_module/module_tracer.py
@@ -79,6 +79,8 @@ BUILTIN_ARRAY_METHOD = [
     "min",
     "max",
     "mean",
+    "__getitem__",
+    "__setitem__",
 ]
 
 
@@ -176,7 +178,8 @@ class Patcher:
             self.patch_module(module)
         for meth in BUILTIN_ARRAY_METHOD:
             self.patch_method(ArrayMethodMixin, meth, self.wrap_fn)
-
+        self.patch_method(Tensor, "detach", self.wrap_fn)
+        self.patch_method(Tensor, "__new__", self.wrap_fn)
         for i, j in self._builtin_functions:
             if id(i) not in self.visited_frames_ids:
                 self.patch_function(i, j, self.wrap_fn)
@@ -203,7 +206,13 @@ class Patcher:
         import inspect
 
         if id(module.__dict__) not in self.visited_frames_ids:
-            for k, v in module.__dict__.items():
+            keys = (
+                getattr(module, "__all__")
+                if hasattr(module, "__all__")
+                else module.__dict__.keys()
+            )
+            for k in keys:
+                v = getattr(module, k)
                 if inspect.isfunction(v) and not k.startswith("_"):
                     self.patch_function(module.__dict__, k, self.wrap_fn)
         self.visited_frames_ids.add(id(module.__dict__))
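Patching `__getitem__`/`__setitem__` (plus `Tensor.detach` and `Tensor.__new__`) means tensor slicing, in-place slice assignment, and `Tensor(x)` construction inside `forward` are now recorded as exprs. A rough sketch, assuming a MegEngine build with this patch applied (`SliceWrite` is a made-up example):

    import megengine.functional as F
    import megengine.module as M
    from megengine import Tensor
    from megengine.experimental.traced_module import trace_module

    class SliceWrite(M.Module):
        def forward(self, x):
            m = F.zeros((2, 2))
            m[0, 0] = x.sum()  # traced via the newly patched __setitem__
            return m

    tm = trace_module(SliceWrite(), Tensor([1.0, 2.0]))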
diff --git a/imperative/python/megengine/experimental/traced_module/node.py b/imperative/python/megengine/experimental/traced_module/node.py
index 9a7436e979ce8caf49b0284451a7f7af72f309a0..bd1fc4c91173b73c37f0d3f3ab5db9c1cd3a8f06 100644
--- a/imperative/python/megengine/experimental/traced_module/node.py
+++ b/imperative/python/megengine/experimental/traced_module/node.py
@@ -6,7 +6,7 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from typing import Any, Dict, Tuple, Type
+from typing import Any, Dict, List, Tuple, Type
 
 import numpy
 
@@ -31,6 +31,7 @@ class Node:
 
     def __init__(self, expr: "Expr", name: str = None):
         self.expr = expr
+        self.users = []  # List[Expr]
         self._id = Node.__total_id
         Node.__total_id += 1
         self._name = name
@@ -59,11 +60,13 @@ class ModuleNode(Node):
     module_type = Module  # type: Type[Module]
     attr_type_map = None  # type: Dict[str, Type[Any]]
     argdef_graph_map = None  # type: Dict[Treedef, "InternalGraph"]
+    argdef_outdef_map = None  # type: Dict[Treedef, Treedef]
 
     def __init__(self, expr: "Expr", name: str = None):
         super().__init__(expr, name)
         self.attr_type_map = {}
         self.argdef_graph_map = {}
+        self.argdef_outdef_map = {}
 
     def __repr__(self):
         if self._name is None:
diff --git a/imperative/python/megengine/experimental/traced_module/pytree.py b/imperative/python/megengine/experimental/traced_module/pytree.py
index f6c5d7ea45686854733b61c1f4dcdd4a20f78aff..74ca1933e06a75866fbfea2f8057b0aed78b07ae 100644
--- a/imperative/python/megengine/experimental/traced_module/pytree.py
+++ b/imperative/python/megengine/experimental/traced_module/pytree.py
@@ -10,6 +10,8 @@
 import collections
 from typing import Callable, NamedTuple
 
+import numpy as np
+
 SUPPORTED_TYPE = {}
 
 NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)])
@@ -33,7 +35,7 @@ def _dict_unflatten(inps, aux_data):
 
 
 register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
-register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: list(x))
+register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x))
 register_supported_type(dict, _dict_flatten, _dict_unflatten)
 register_supported_type(
     slice,
@@ -52,7 +54,10 @@ def tree_flatten(
         assert is_leaf(values)
         node = LeafDef(leaf_type(values))
         if is_const_leaf(values):
-            node.const_val = values
+            if isinstance(values, np.ndarray):
+                node.const_val = str(values)
+            else:
+                node.const_val = values
         return [values,], node
 
     rst = []
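The pytree change fixes tuple round-tripping: a `tuple` previously unflattened back to a `list`, which would break `out_def.unflatten(outputs)` for modules returning tuples. Storing `str(values)` for `np.ndarray` constants presumably keeps `LeafDef` comparable and hashable. A quick check of the tuple fix (assuming `tree_flatten`'s default leaf settings accept plain ints):

    from megengine.experimental.traced_module.pytree import tree_flatten

    leaves, treedef = tree_flatten((1, 2))
    assert treedef.unflatten(leaves) == (1, 2)  # returned [1, 2] before this fix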
diff --git a/imperative/python/megengine/experimental/traced_module/traced_module.py b/imperative/python/megengine/experimental/traced_module/traced_module.py
index 9c80eb85561ea8829ab565a9278712bdcd44df8b..38869012be651a0a20b510110743402be7fe7380 100644
--- a/imperative/python/megengine/experimental/traced_module/traced_module.py
+++ b/imperative/python/megengine/experimental/traced_module/traced_module.py
@@ -10,8 +10,12 @@
 import collections
 import copy
 import functools
 from inspect import getmembers, isclass, ismethod
-from typing import Dict, List, Type
+from typing import Callable, Dict, Iterable, List, Sequence, Type
 
+import numpy as np
+
+from ... import functional as F
+from ... import get_logger
 from ... import module as M
 from ...core._imperative_rt.core2 import Tensor as RawTensor
 from ...core._imperative_rt.core2 import (
@@ -19,6 +24,7 @@
     set_module_tracing,
     unset_module_tracing,
 )
+from ...core._trace_option import set_symbolic_shape
 from ...core.tensor.array_method import ArrayMethodMixin
 from ...module import Module
 from ...tensor import Tensor
@@ -32,6 +38,8 @@
 from .node import ModuleNode, Node, NodeMixin, TensorNode
 from .pytree import tree_flatten
 
+logger = get_logger(__name__)
+
 
 def _leaf_type(node):
     if isinstance(node, RawTensor):
@@ -42,6 +50,11 @@
     return type(node)
 
 
+def _is_leaf(node):
+    assert isinstance(node, RawTensor), type(node)
+    return isinstance(node, RawTensor)
+
+
 def _is_const_leaf(node):
     if isinstance(node, (RawTensor, NodeMixin, Module)):
         return False
@@ -80,7 +93,13 @@ class InternalGraph:
 
     @property
     def exprs(self):
-        return _expr_list(self)
+        return ExprFilter(_expr_iter(self))
+
+    def get_call_function(self, func: Callable = None):
+        return self.exprs.call_function(func)
+
+    def get_call_method(self, method: str = None):
+        return self.exprs.call_method(method)
 
     def add_input(self, i):
         self._inputs.append(i)
@@ -88,16 +107,131 @@ class InternalGraph:
     def add_output(self, o):
         self._outputs.append(o)
 
+    def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]:
+        if not isinstance(nodes, Sequence):
+            nodes = (nodes,)
+        ret = list()
+        queue = list(nodes)
+        while queue:
+            node = queue.pop()
+            expr = node.expr
+            if expr not in ret:
+                ret.append(expr)
+            for i in expr.inputs:
+                if i not in queue:
+                    queue.append(i)
+        return ret
+
+    def insert_call_function(self, func: Callable, nodes: Sequence[Node]):
+        if not isinstance(nodes, Sequence):
+            nodes = [nodes]
+        assert isinstance(func, Callable)
+        for i in nodes:
+            assert isinstance(
+                i, TensorNode
+            ), "CallFunction only accepts TensorNode as inputs"
+
+        expr = CallFunction(func)
+        expr.inputs = nodes
+
+        for i in nodes:
+            i.users.append(expr)
+
+        idx = max(self._exprs.index(i.expr) for i in nodes) + 1
+        self._exprs.insert(idx, expr)
+
+        fake_inp_val = tuple(F.zeros(shape=i.shape, dtype=i.dtype) for i in nodes)
+        fake_out_val = func(*fake_inp_val)
+
+        def create_node(val: Tensor):
+            node = TensorNode(expr)
+            node.shape = val.shape
+            node.dtype = val.dtype
+            return node
+
+        out_nodes = list(create_node(i) for i in fake_out_val)
+        expr.outputs = out_nodes
+
+        return out_nodes
+
+    def insert_call_method(self, target, method, args):
+        if not isinstance(args, Sequence):
+            args = [args]
+        assert isinstance(target, (TensorNode, ModuleNode))
+        assert isinstance(method, str)
+        for i in args:
+            assert isinstance(i, TensorNode)
+
+        expr = CallMethod(method)
+        expr.inputs = [target, *args]
+
+        if isinstance(target, TensorNode):
+            fake_target_val = F.zeros(shape=target.shape, dtype=target.dtype)
+            fake_inp_val = tuple(F.zeros(shape=i.shape, dtype=i.dtype) for i in args)
+            fake_out_val = getattr(fake_target_val, method)(fake_inp_val)
+
+            def create_node(val: Tensor):
+                node = TensorNode(expr)
+                node.shape = val.shape
+                node.dtype = val.dtype
+                return node
+
+            out_nodes = list(create_node(i) for i in fake_out_val)
+            expr.outputs = out_nodes
+        else:
+            raise NotImplementedError()
+
+        return out_nodes
+
+    def replace_node(self, repl_dict: Dict[Node, Node]):
+        while repl_dict:
+            node, repl_node = repl_dict.popitem()
+            # check graph inputs and outputs
+            assert node not in self.inputs, "Cannot replace inputs"
+            for i, n in enumerate(self.outputs):
+                if n is node:
+                    self.outputs[i] = repl_node
+            # update users of node and repl_node
+            # update inputs of expr in node.users
+            dep_exprs = self.get_dep_exprs(repl_node)
+            i = 0
+            while i < len(node.users):
+                n = node.users[i]
+                if n in dep_exprs:
+                    logger.info("Found a loop: ignoring this replacement once")
+                    logger.info("node: %s" % node.__repr__())
+                    logger.info("repl_node: %s" % repl_node.__repr__())
+                    i += 1
+                    continue
+                repl_node.users.append(n)
+                node.users.pop(i)
+                idx = n.inputs.index(node)
+                n.inputs[idx] = repl_node
+
+    def compile(self):
+        """
+        Delete unused exprs.
+        """
+        dep_exprs = self.get_dep_exprs(self.outputs)
+        i = 0
+        while i < len(self._exprs):
+            expr = self._exprs[i]
+            if expr in dep_exprs:
+                i += 1
+                continue
+            for n in expr.inputs:
+                n.users.remove(expr)
+            self._exprs.remove(expr)
+
     def interpret(self, *inputs):
-        # TODO: support kwargs ?
-        # TODO: skip expressions which are independent and have no side effect
         node2value = {}
         for n, v in zip(self._inputs, inputs):
             node2value[n] = v
         for expr in self._exprs:
             values = expr.interpret(*list(node2value[i] for i in expr.inputs))
-            for n, v in zip(expr.outputs, values):
-                node2value[n] = v
+            if values is not None:
+                for n, v in zip(expr.outputs, values):
+                    node2value[n] = v
         return list(node2value[i] for i in self._outputs)
 
     def __repr__(self):
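Together, `get_dep_exprs`, `insert_call_function`, `replace_node`, and `compile` form a basic graph-surgery workflow: locate a node, splice a new expr after it, redirect its users, then drop anything no longer reachable from the outputs. A sketch mirroring `test_insert` in the tests below (`traced_module` is assumed to be any `TracedModule` whose graph contains an `F.relu` call):

    import megengine.functional as F

    graph = traced_module.graph
    relu_out = graph.get_call_function(F.relu).as_unique().outputs
    neg_out = graph.insert_call_function(F.neg, relu_out)
    graph.replace_node({relu_out[0]: neg_out[0]})
    graph.compile()  # removes the now-unused relu expr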
@@ -109,7 +243,8 @@
 
 
 def _get_meth_name(obj, func):
-    for cls in type(obj).mro():
+    tp = obj if isinstance(obj, type) else type(obj)
+    for cls in tp.mro():
         for k, v in cls.__dict__.items():
             if v == func:
                 return k
@@ -131,15 +266,31 @@ def _wrapped_function(orig_func):
             meth_name = _get_meth_name(args[0], wrapped_fn)
             if meth_name:
                 self = inputs[0]
-                call_node = CallMethod.make(NodeMixin.get(self), meth_name)
+                if meth_name == "__new__":
+                    if all([not isinstance(i, RawTensor) for i in inputs]):
+                        # only trace Tensor.__new__() when there are tensors in args
+                        set_module_tracing()
+                        return orig_func(*args, **kwargs)
+                    if isinstance(args[1], RawTensor):
+                        node = NodeMixin.get(inputs[1])
+                        inputs[1] = copy.copy(inputs[1])
+                        # copy inputs[1] so that tensor and Tensor(tensor) do not share one m_tensor,
+                        # which would give them the same _NodeMixin__node during tracing
+                        NodeMixin.wrap_safe(inputs[1], node)
+                        args, kwargs = tree_def.unflatten(inputs)
+                    call_node = CallMethod.make(self, meth_name)
+                else:
+                    call_node = CallMethod.make(NodeMixin.get(self), meth_name)
+                call_node.add_inputs(inputs[1:])
             else:
                 call_node = CallFunction.make(orig_func)
-
-            call_node.add_inputs(inputs)
+                call_node.add_inputs(inputs)
 
             call_node.arg_def = tree_def
             outputs = orig_func(*args, **kwargs)
-            call_node.add_outputs(outputs)
+            if meth_name == "__new__":
+                call_node.add_outputs(outputs, False)
+            else:
+                call_node.add_outputs(outputs)
             set_module_tracing()
             return outputs
         return orig_func(*args, **kwargs)
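The `__new__` branch makes `Tensor(x)` construction inside `forward` traceable: the argument tensor is shallow-copied so the original and the new tensor do not share one `m_tensor` (and hence one wrapped node), and `add_outputs(..., False)` skips the in-place check because `__new__` returns a fresh tensor. A sketch of the now-supported pattern, following `test_trace_module.py` below:

    import megengine.module as M
    from megengine import Tensor
    from megengine.experimental.traced_module import trace_module

    class CopyCtor(M.Module):
        def forward(self, x):
            y = Tensor(x)  # Tensor.__new__ is traced as a CallMethod expr
            y += 1
            return x + 2, y

    tm = trace_module(CopyCtor(), Tensor(1))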
@@ -197,13 +348,14 @@ class TracedModuleBuilder(NodeMixin):
                 mark_constant(i)
             callnode = CallMethod.make(NodeMixin.get(self))
 
-            callnode.add_inputs(inputs)
+            callnode.add_inputs(inputs[1:])
 
             callnode.arg_def = tree_def
 
             if self._is_builtin:
                 unset_module_tracing()
-                outputs = self._mod(*args, **kwargs)
+                rst = self._mod(*args, **kwargs)
+                outputs, out_def = tree_flatten(rst, leaf_type=_leaf_type, is_leaf=_is_leaf)
                 set_module_tracing()
             if self._is_builtin:
                 self._body = None
@@ -215,14 +367,13 @@ class TracedModuleBuilder(NodeMixin):
                 NodeMixin.wrap_safe(
                     self, Input.make("self", NodeMixin.get_wrapped_type(self))
                 )
+                origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]
                 # prepare args and kwargs for inner graph
                 def wrap(x):
-                    wrapped = copy.copy(x)  # FIXME
                     NodeMixin.wrap(
-                        wrapped,
-                        lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)),
+                        x, lambda: Input.make(type=NodeMixin.get_wrapped_type(x)),
                     )
-                    return wrapped
+                    return x
 
                 args = [self]
                 for i in inputs[1:]:
@@ -231,21 +382,25 @@ class TracedModuleBuilder(NodeMixin):
                 active_module_tracer().patcher.auto_patch(
                     getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
                 )
-                outputs = type(self._mod).forward(*args, **kwargs)
-
+                rst = type(self._mod).forward(*args, **kwargs)
+                outputs, out_def = tree_flatten(rst, leaf_type=_leaf_type, is_leaf=_is_leaf)
                 for i in (
                     outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
                 ):
                     active_module_tracer().current_scope().add_output(NodeMixin.get(i))
 
                 NodeMixin.wrap_safe(self, orig_self)
+                for arg, node in zip(inputs[1:], origin_inp_node):
+                    if node:
+                        NodeMixin.wrap_safe(arg, node)
                 active_module_tracer().pop_scope()
 
             # rebind output to outer graph
             callnode.add_outputs(outputs)
             self_node = NodeMixin.get(self)
             self_node.argdef_graph_map[callnode.arg_def] = self._body
-            return outputs
+            self_node.argdef_outdef_map[callnode.arg_def] = out_def
+            return rst
 
     def __getattr__(self, name):
         if name not in self._mod.__dict__:
@@ -268,20 +423,29 @@ class TracedModuleBuilder(NodeMixin):
             return super().__getattribute__(name)
         else:
             wrapped = super().__getattribute__(name)
-            if name in self._mod.__dict__ and not NodeMixin.get(wrapped, None):
-                assert not self._is_builtin
-                NodeMixin.wrap(
-                    wrapped,
-                    lambda: GetAttr.make(
+            if name in self._mod.__dict__:
+                if not NodeMixin.get(wrapped, None):
+                    assert not self._is_builtin
+                    NodeMixin.wrap(
+                        wrapped,
+                        lambda: GetAttr.make(
+                            NodeMixin.get(self),
+                            name,
+                            type=NodeMixin.get_wrapped_type(wrapped),
+                        ),
+                    )
+                else:
+                    node = NodeMixin.get(wrapped)
+                    expr = GetAttr.make(
                         NodeMixin.get(self),
                         name,
                         type=NodeMixin.get_wrapped_type(wrapped),
-                    ),
-                )
+                    ).expr
+                    expr.outputs[0] = node
         return wrapped
 
 
-class _expr_list:
+class _expr_iter:
     def __init__(self, graph: InternalGraph):
         self.graph = graph
@@ -295,6 +459,59 @@
             yield expr
 
 
+class ExprFilter:
+    def __init__(self, expr_iter: Iterable):
+        self._iter = expr_iter
+
+    def __iter__(self):
+        return iter(self._iter)
+
+    def call_function(self, func):
+        return ExprFilterCallFunction(self, func)
+
+    def call_method(self, method):
+        return ExprFilterCallMethod(self, method)
+
+    def as_list(self):
+        return list(self)
+
+    def as_dict(self):
+        raise NotImplementedError("need key")
+
+    def as_unique(self):
+        (expr,) = self
+        return expr
+
+    def as_count(self):
+        return sum(1 for _ in self)
+
+
+class ExprFilterCallFunction(ExprFilter):
+    def __init__(self, expr_iter, func: Callable = None):
+        super().__init__(expr_iter)
+        self.func = func
+
+    def __iter__(self):
+        for i in self._iter:
+            if not isinstance(i, CallFunction):
+                continue
+            if self.func is None or i.func == self.func:
+                yield i
+
+
+class ExprFilterCallMethod(ExprFilter):
+    def __init__(self, expr_iter, method: str = None):
+        super().__init__(expr_iter)
+        self.method = method
+
+    def __iter__(self):
+        for i in self._iter:
+            if not isinstance(i, CallMethod):
+                continue
+            if self.method is None or i.method == self.method:
+                yield i
+
+
 class TracedModule(Module):
     """
     `TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node). `TracedModule` can not be called directly. It can be
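`ExprFilter` and its subclasses are lazy iterators over the graph, so queries compose without materializing intermediate expr lists. Typical usage (a sketch; `traced_module` is assumed to be any `TracedModule`):

    import megengine.functional as F

    graph = traced_module.graph
    relu_expr = graph.get_call_function(F.relu).as_unique()  # asserts exactly one match
    calls = graph.get_call_method("__call__").as_list()      # every submodule call
    n_relu = graph.get_call_function(F.relu).as_count()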
@@ -312,10 +529,12 @@ class TracedModule(Module):
             ((self, *args), kwargs), _leaf_type, is_const_leaf=_is_const_leaf
         )
         assert treedef in self.m_node.argdef_graph_map
-        inputs = [i for i in inputs if isinstance(i, (Module, RawTensor))]
+        inputs = filter(
+            lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs
+        )  # allow TracedModuleBuilder for retrace.
         outputs = self.m_node.argdef_graph_map[treedef].interpret(*inputs)
-        if len(outputs) == 1:
-            return outputs[0]
+        out_def = self.m_node.argdef_outdef_map[treedef]
+        outputs = out_def.unflatten(outputs)
         return outputs
 
     @property
@@ -339,9 +558,8 @@
             if graph is None:
                 assert not isinstance(module, TracedModule)
                 const = Constant(module)
-                modulenode = const.outputs[0]
-                modulenode.module_type = type(module)
-                call.inputs[0] = modulenode
+                const.outputs[0] = call.inputs[0]
+                const.outputs[0].expr = const
                 return [const, call]
             exprs = []
             for expr in graph._exprs:
@@ -350,30 +568,41 @@
                     if call and inp in graph._inputs:
                         inp_idx = graph._inputs.index(inp)
                         expr.inputs[idx] = call.inputs[inp_idx]
+                        call.inputs[inp_idx].users.append(expr)
                 # replace outputs for submodule's expr
                 for idx, outp in enumerate(expr.outputs):
                     if call and outp in graph._outputs:
                         oup_idx = graph._outputs.index(outp)
                         expr.outputs[idx] = call.outputs[oup_idx]
+                        call.outputs[oup_idx].expr = expr
 
                 if isinstance(expr, GetAttr):
                     # replace GetAttr with Constant
                     if isinstance(expr.outputs[0], TensorNode):
                         const = Constant(getattr(module, expr.name))
                         const.outputs = expr.outputs
+                        const.outputs[0].expr = const
                         exprs.append(const)
                 elif isinstance(expr, CallMethod):
                     obj_node = expr.inputs[0]
                     if isinstance(obj_node, ModuleNode):
-                        assert isinstance(expr.inputs[0].expr, GetAttr)
-                        (obj,) = expr.inputs[0].expr.interpret(module)
-                        exprs.extend(_flatten_subgraph(expr.graph, obj, expr))
+                        pre_expr = expr.inputs[0].expr
+                        if isinstance(pre_expr, GetAttr):
+                            (obj,) = expr.inputs[0].expr.interpret(module)
+                            exprs.extend(_flatten_subgraph(expr.graph, obj, expr))
+                        else:
+                            # module has been replaced.
+                            assert isinstance(pre_expr, Constant)
                    else:
                         exprs.append(expr)
                 else:
                     exprs.append(expr)
 
+            if call is not None:
+                for i in call.inputs:
+                    i.users.remove(call)
+
             return exprs
 
         new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module)
@@ -422,22 +651,26 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
     """
     assert active_module_tracer() is None
     try:
+        use_sym_shape = set_symbolic_shape(True)
         set_module_tracing()
         set_active_module_tracer(module_tracer(_wrapped_function))
+
        with active_module_tracer().patcher:
             global_scope = InternalGraph()
             active_module_tracer().push_scope(global_scope)
 
             builder = TracedModuleBuilder(mod, True)
             NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
-            inputs, _ = tree_flatten((args, kwargs))
+            inputs, _ = tree_flatten((args, kwargs), is_const_leaf=_is_const_leaf)
             for _, i in enumerate(inputs):
-                NodeMixin.wrap_safe(
-                    i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i))
-                )
+                if isinstance(i, RawTensor):
+                    NodeMixin.wrap_safe(
+                        i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i))
+                    )
             builder(*args, **kwargs)
             active_module_tracer().pop_scope()
             return builder.build()
     finally:
+        set_symbolic_shape(use_sym_shape)
         set_active_module_tracer(None)
         unset_module_tracing()
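With `argdef_outdef_map` recorded at trace time, `TracedModule.__call__` can unflatten the interpreted leaves back into the original return structure, so a `forward` returning a dict (as the `PreProcess` module in `test_wujianan.py` below does) round-trips intact. A minimal sketch (`DictOut` is illustrative):

    import megengine.module as M
    from megengine import Tensor
    from megengine.experimental.traced_module import trace_module

    class DictOut(M.Module):
        def forward(self, x):
            return {"out": x * 2}

    tm = trace_module(DictOut(), Tensor([1.0]))
    assert isinstance(tm(Tensor([3.0])), dict)  # out_def restores the dict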
diff --git a/imperative/python/test/unit/traced_module/test_haoruitao.py b/imperative/python/test/unit/traced_module/test_haoruitao.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdf8cd4a131f4e25c8a0b35b8268386834a582fc
--- /dev/null
+++ b/imperative/python/test/unit/traced_module/test_haoruitao.py
@@ -0,0 +1,90 @@
+import io
+import pickle
+
+import numpy as np
+
+import megengine.functional as F
+import megengine.module as M
+import megengine.utils.comp_graph_tools as cgtools
+from megengine.core._trace_option import set_symbolic_shape
+from megengine.experimental.traced_module import trace_module
+from megengine.jit import trace
+
+set_symbolic_shape(True)
+
+
+class Main(M.Module):
+    def forward(self, x):
+        return x
+
+
+class PreProcess(M.Module):
+    def __init__(self):
+        super().__init__()
+        self.I = F.ones((1,))
+        self.M = F.zeros((1,))
+
+    def forward(self, data, idx, roi):
+        N, H, W, C = data.shape
+        xmax = roi[:, 1, 0]
+        xmin = roi[:, 0, 0]
+        ymax = roi[:, 1, 1]
+        ymin = roi[:, 0, 1]
+        scale = F.maximum((xmax - xmin) / W, (ymax - ymin) / H)
+        I = F.broadcast_to(self.I, (N,))
+        M = F.broadcast_to(self.M, (N, 3, 3))
+        M[:, 0, 0] = scale
+        M[:, 0, 2] = xmin
+        M[:, 1, 1] = scale
+        M[:, 1, 2] = ymin
+        M[:, 2, 2] = I
+        resized = (
+            F.warp_perspective(
+                data, M, (H, W), mat_idx=idx, border_mode="CONSTANT", format="NHWC"
+            )
+            .transpose(0, 3, 1, 2)
+            .astype(np.float32)
+        )
+        return resized
+
+
+class Net(M.Module):
+    def __init__(self, traced_module):
+        super().__init__()
+        self.pre_process = PreProcess()
+        self.traced_module = traced_module
+
+    def forward(self, data, idx, roi):
+        x = self.pre_process(data, idx, roi)
+        x = self.traced_module(x)
+        return x
+
+
+def test_preprocess():
+    module = Main()
+    data = F.ones((1, 14, 8, 8), dtype=np.uint8)
+    traced_module = trace_module(module, data)
+    obj = pickle.dumps(traced_module)
+    traced_module = pickle.loads(obj)
+    module = Net(traced_module)
+    module.eval()
+    idx = F.zeros((1,), dtype=np.int32)
+    roi = F.ones((1, 2, 2), dtype=np.float32)
+    y = module(data, idx, roi)
+    traced_module = trace_module(module, data, idx, roi)
+    np.testing.assert_array_equal(traced_module(data, idx, roi), y)
+    func = trace(traced_module, capture_as_const=True)
+    np.testing.assert_array_equal(func(data, idx, roi), y)
+    model = io.BytesIO()
+    func.dump(model, arg_names=("data", "idx", "roi"))
+    model.seek(0)
+    infer_cg = cgtools.GraphInference(model)
+    np.testing.assert_allclose(
+        list(
+            infer_cg.run(
+                inp_dict={"data": data.numpy(), "idx": idx.numpy(), "roi": roi.numpy()}
+            ).values()
+        )[0],
+        y,
+        atol=1e-6,
+    )
diff --git a/imperative/python/test/unit/traced_module/test_modification.py b/imperative/python/test/unit/traced_module/test_modification.py
new file mode 100644
index 0000000000000000000000000000000000000000..692fbb0b1a704af3c50252c6de46cd3d6d077c97
--- /dev/null
+++ b/imperative/python/test/unit/traced_module/test_modification.py
@@ -0,0 +1,113 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import numpy as np
+
+import megengine.functional as F
+import megengine.module as M
+from megengine.experimental.traced_module import trace_module
+from megengine.experimental.traced_module.expr import CallFunction, GetAttr
+
+
+class MyBlock(M.Module):
+    def __init__(self, in_channels=3, channels=3):
+        super(MyBlock, self).__init__()
+        self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=1, bias=False)
+        self.bn1 = M.BatchNorm2d(channels)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = F.relu(x) + 1
+        return x
+
+
+class MyModule(M.Module):
+    def __init__(self):
+        super(MyModule, self).__init__()
+        self.block0 = MyBlock()
+        self.block1 = MyBlock()
+
+    def forward(self, x):
+        x = self.block0(x)
+        x = self.block1(x)
+        return x
+
+
+def _init_cls(cls):
+    module = cls()
+    x = F.ones((1, 3, 3, 3))
+    y = module(x)
+    traced_module = trace_module(module, x)
+    return traced_module, x, y
+
+
+def _init_block():
+    return _init_cls(MyBlock)
+
+
+def _init_module():
+    return _init_cls(MyModule)
+
+
+def test_search():
+    traced_module, *_ = _init_block()
+    graph = traced_module.graph
+    relu_expr = graph.get_call_function(F.relu).as_unique()
+    assert isinstance(relu_expr, CallFunction) and relu_expr.func == F.relu
+
+
+def test_insert():
+    traced_module, x, expect = _init_block()
+    graph = traced_module.graph
+    relu_node = graph.get_call_function(F.relu).as_unique().outputs
+    neg_node = graph.insert_call_function(F.neg, relu_node)
+    graph.replace_node({relu_node[0]: neg_node[0]})
+    graph.compile()
+    np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6)
+
+
+def test_delete():
+    traced_module, x, expect = _init_block()
+    graph = traced_module.graph
+    relu_expr = graph.get_call_function(F.relu).as_unique()
+    node = relu_expr.outputs
+    repl_node = relu_expr.inputs
+    graph.replace_node({node[0]: repl_node[0]})
+    graph.compile()
+    np.testing.assert_allclose(expect - 1, F.relu(traced_module(x) - 1), atol=1e-6)
+
+
+def test_flatten():
+    traced_module, x, expect = _init_module()
+    traced_module = traced_module.flatten()
+    traced_module.graph.compile()
+    assert all(not isinstance(i, GetAttr) for i in traced_module.graph._exprs)
+    assert len(traced_module.graph._exprs) == 12
+
+
+def test_extra_block():
+    class PostProcess(M.Module):
+        def forward(self, x):
+            return x * 2
+
+    class Net(M.Module):
+        def __init__(self, traced_module):
+            super().__init__()
+            self.post_process = PostProcess()
+            self.traced_module = traced_module
+
+        def forward(self, x):
+            x = self.traced_module(x)
+            x = self.post_process(x)
+            return x
+
+    traced_module, x, expect = _init_block()
+    module = Net(traced_module)
+    np.testing.assert_allclose(2 * expect, module(x), atol=1e-6)
+    traced_module = trace_module(module, x)
+    np.testing.assert_allclose(2 * expect, traced_module(x), atol=1e-6)
diff --git a/imperative/python/test/unit/traced_module/test_serialization.py b/imperative/python/test/unit/traced_module/test_serialization.py
new file mode 100644
index 0000000000000000000000000000000000000000..be4d2ff00f3875fb766eb68d2101a6a013f9ffd2
--- /dev/null
+++ b/imperative/python/test/unit/traced_module/test_serialization.py
@@ -0,0 +1,52 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import pickle
+
+import numpy as np
+
+import megengine.functional as F
+import megengine.module as M
+from megengine import Tensor
+from megengine.experimental.traced_module import trace_module
+from megengine.module import Module
+
+
+class MyBlock(Module):
+    def __init__(self, in_channels, channels):
+        super(MyBlock, self).__init__()
+        self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=1, bias=False)
+        self.bn1 = M.BatchNorm2d(channels)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = F.relu(x) + 1
+        return x
+
+
+class MyModule(Module):
+    def __init__(self):
+        super(MyModule, self).__init__()
+        self.block0 = MyBlock(8, 4)
+        self.block1 = MyBlock(4, 2)
+
+    def forward(self, x):
+        x = self.block0(x)
+        x = self.block1(x)
+        return x
+
+
+def test_dump_and_load():
+    module = MyModule()
+    x = Tensor(np.ones((1, 8, 14, 14)))
+    expect = module(x)
+    traced_module = trace_module(module, x)
+    np.testing.assert_array_equal(expect, traced_module(x))
+    obj = pickle.dumps(traced_module)
+    pickle.loads(obj)
+    np.testing.assert_array_equal(expect, traced_module(x))
diff --git a/imperative/python/test/unit/traced_module/test_trace_module.py b/imperative/python/test/unit/traced_module/test_trace_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4475ecd42b48e2139ef42fed4d00b97bbbbb85f
--- /dev/null
+++ b/imperative/python/test/unit/traced_module/test_trace_module.py
@@ -0,0 +1,42 @@
+import numpy as np
+
+from megengine import Tensor
+from megengine.experimental.traced_module import trace_module
+from megengine.module import Module as M
+
+
+class MyModule1(M):
+    def forward(self, x):
+        y = Tensor(x)
+        y += 1
+        x = x + 2
+        return x, y
+
+
+class MyModule2(M):
+    def forward(self, x):
+        y = Tensor([1, x, 1])
+        y += 1
+        x = x + 2
+        return x, y
+
+
+def test_trace_module():
+
+    x = Tensor(1)
+    m1 = MyModule1()
+    tm1 = trace_module(m1, x)
+
+    m2 = MyModule2()
+    tm2 = trace_module(m2, x)
+    inp = Tensor(2)
+    gt = m1(inp)
+    output = tm1(inp)
+    for a, b in zip(output, gt):
+        np.testing.assert_equal(a.numpy(), b.numpy())
+
+    gt1 = m2(inp)
+    output1 = tm2(inp)
+
+    for a, b in zip(output1, gt1):
+        np.testing.assert_equal(a.numpy(), b.numpy())
diff --git a/imperative/python/test/unit/traced_module/test_wujianan.py b/imperative/python/test/unit/traced_module/test_wujianan.py
new file mode 100644
index 0000000000000000000000000000000000000000..44474f6daeafa140d90bed63399a978d3cb757f3
--- /dev/null
+++ b/imperative/python/test/unit/traced_module/test_wujianan.py
@@ -0,0 +1,94 @@
+import io
+import pickle
+
+import numpy as np
+
+import megengine as mge
+import megengine.functional as F
+import megengine.module as M
+import megengine.utils.comp_graph_tools as cgtools
+from megengine.core._trace_option import set_symbolic_shape
+from megengine.experimental.traced_module import trace_module
+from megengine.jit import trace
+
+set_symbolic_shape(True)
+
+
+class Main(M.Module):
+    def forward(self, x):
+        return x["data"]
+
+
+class PreProcess(M.Module):
+    def __init__(self):
+        super().__init__()
+        self.A = F.zeros((1,))
+        self.I = F.ones((1,))
+        self.bb_out = mge.tensor(
+            np.array([[[0, 0], [160, 0], [160, 48], [0, 48]]], dtype="float32")
+        )
+
+    def forward(self, data, quad):
+        """
+        data: (1, 3, 48, 160)
+        quad: (1, 4, 2)
+        """
+        N = quad.shape[0]
+        dst = F.repeat(self.bb_out, N, axis=0).reshape(-1, 4, 2)
+        I = F.broadcast_to(self.I, quad.shape)
+        A = F.broadcast_to(self.A, (N, 8, 8))
+        A[:, 0:4, 0:2] = quad
+        A[:, 4:8, 5:6] = I[:, :, 0:1]
+        A[:, 0:4, 6:8] = -quad * dst[:, :, 0:1]
+        A[:, 4:8, 3:5] = quad
+        A[:, 0:4, 2:3] = I[:, :, 0:1]
+        A[:, 4:8, 6:8] = -quad * dst[:, :, 1:2]
+        B = dst.transpose(0, 2, 1).reshape(-1, 8, 1)
+        M = F.concat([F.matmul(F.matinv(A), B)[:, :, 0], I[:, 0:1, 0]], axis=1).reshape(
+            -1, 3, 3
+        )
+        new_data = F.warp_perspective(data, M, (48, 160))  # (N, 3, 48, 160)
+        return {"data": new_data}
+
+
+class Net(M.Module):
+    def __init__(self, traced_module):
+        super().__init__()
+        self.pre_process = PreProcess()
+        self.traced_module = traced_module
+
+    def forward(self, data, quad):
+        x = self.pre_process(data, quad)
+        x = self.traced_module(x)
+        return x
+
+
+def test_preprocess():
+    batch_size = 2
+    module = Main()
+    data = mge.tensor(
+        np.random.randint(0, 256, size=(batch_size, 3, 48, 160)), dtype=np.float32
+    )
+    traced_module = trace_module(module, {"data": data})
+    obj = pickle.dumps(traced_module)
+    traced_module = pickle.loads(obj)
+    module = Net(traced_module)
+    module.eval()
+    quad = mge.tensor(np.random.normal(size=(batch_size, 4, 2)), dtype=np.float32)
+    expect = module(data, quad)
+    traced_module = trace_module(module, data, quad)
+    actual = traced_module(data, quad)
+    for i, j in zip(expect, actual):
+        np.testing.assert_array_equal(i, j)
+    func = trace(traced_module, capture_as_const=True)
+    actual = func(data, quad)
+    for i, j in zip(expect, actual):
+        np.testing.assert_array_equal(i, j)
+    model = io.BytesIO()
+    func.dump(model, arg_names=("data", "quad"))
+    model.seek(0)
+    infer_cg = cgtools.GraphInference(model)
+    actual = list(
+        infer_cg.run(inp_dict={"data": data.numpy(), "quad": quad.numpy()}).values()
+    )[0]
+    np.testing.assert_allclose(expect, actual)