diff --git a/imperative/python/megengine/traced_module/__init__.py b/imperative/python/megengine/traced_module/__init__.py index 848c968e2cea6b97435dce0e7480ed14cf2f7345..6bbdc6689403dc34ee58825ddc324703b2b143b3 100644 --- a/imperative/python/megengine/traced_module/__init__.py +++ b/imperative/python/megengine/traced_module/__init__.py @@ -10,6 +10,7 @@ from ..core._imperative_rt.core2 import set_cpp_apply_module_trace from . import compat from ._passes import optimize from .pytree import register_supported_type +from .tm_config import disable_default_checker, enable_expr_checker from .traced_module import ( TracedModule, _register_all_builtin_module, @@ -29,4 +30,6 @@ __all__ = [ "wrap", "TracedModule", "optimize", + "enable_expr_checker", + "disable_default_checker", ] diff --git a/imperative/python/megengine/traced_module/checker.py b/imperative/python/megengine/traced_module/checker.py new file mode 100644 index 0000000000000000000000000000000000000000..31fa0470b5daf45ece7f76f830e586b34e70869d --- /dev/null +++ b/imperative/python/megengine/traced_module/checker.py @@ -0,0 +1,142 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
class TracedModuleChecker:
    r"""Keep the concrete value behind each traced node and re-check exprs.

    While a module is being traced, every node that gets wrapped is recorded
    together with its value in a per-scope dictionary. An expr can then be
    re-interpreted on the recorded input values and its result compared
    against the outputs that were actually produced during tracing.

    Args:
        tracer: the owning tracer; only ``tracer.current_scope()`` is used,
            to remember which scope was active when a node had no recorded value.
    """

    def __init__(self, tracer):
        # Stack of ``{node: value}`` dicts, one per active scope (innermost last).
        self._active_node2values = []
        self.tracer = tracer
        # node -> [scope that was current, call stack] for failed value lookups.
        self.node_without_tensor_info = {}

    def push_scope(self):
        """Open a fresh, empty ``{node: value}`` table for a new scope."""
        self._active_node2values.append({})

    def pop_scope(self):
        """Discard the ``{node: value}`` table of the innermost scope."""
        self._active_node2values.pop()

    def current_node2values(self):
        """Return the ``{node: value}`` table of the innermost scope."""
        return self._active_node2values[-1]

    def reset_checker(self):
        """Drop every recorded scope table."""
        self._active_node2values = []

    def check_node_not_in_scope(self):
        """Report whether any expr input had no recorded value.

        Returns:
            ``False`` when every lookup succeeded, ``True`` when some lookups
            failed but none of the affected nodes is used by the exprs of the
            scope that was current at that time.

        Raises:
            ValueError: if a node without a recorded value is an input or an
                output of an expr of that scope. The call stack captured at
                lookup time is printed first.
        """
        if not self.node_without_tensor_info:
            return False
        for node, (scope, stack) in self.node_without_tensor_info.items():
            for expr in scope._exprs:
                if node in expr.inputs or node in expr.outputs:
                    traceback.print_list(stack)
                    raise ValueError(
                        "node({}) not in the graph:\n{}".format(node, scope)
                    )
        return True

    def check_net_outputs(self, tm_res, gt_res):
        """Assert that the traced module's outputs match the reference outputs.

        Args:
            tm_res: a ``Tensor``, a sequence of objects exposing ``numpy()``,
                or any other object whose instance attributes expose ``numpy()``.
            gt_res: the reference result with the same structure as ``tm_res``.

        Raises:
            AssertionError: raised by ``np.testing.assert_allclose`` on mismatch.
        """
        if isinstance(tm_res, Tensor):
            np.testing.assert_allclose(tm_res.numpy(), gt_res.numpy())
        elif isinstance(tm_res, Sequence):
            for actual, expected in zip(tm_res, gt_res):
                np.testing.assert_allclose(actual.numpy(), expected.numpy())
        else:
            for key in tm_res.__dict__:
                np.testing.assert_allclose(
                    getattr(tm_res, key).numpy(), getattr(gt_res, key).numpy()
                )

    def record_nodemixin(self, node, value):
        """Record ``value`` itself (no copy) as the value of ``node``."""
        self.current_node2values()[node] = value

    def record_node2value(self, node, value):
        """Record a copy of tensor ``value`` as the value of ``node``.

        The copy is created with a ``Copy`` op on the tensor's own device while
        tracing is excluded; the scalar flag of ``value`` is carried over.
        """
        with _exclude_from_trace():
            copied = apply(Copy(comp_node=value.device), value)[0]
            self.current_node2values()[node] = copied
            if isscalar(value):
                setscalar(copied)

    def check_apply_special_cases(self, opdef, num_outputs):
        """Return the output indexes that should be compared for ``opdef``.

        All ``num_outputs`` indexes are returned, except that the last one is
        dropped for ``ROIAlign`` / ``ROIPooling`` ops in ``AVERAGE`` mode.
        """
        indexs = list(range(num_outputs))
        if isinstance(opdef, ROIAlign) and opdef.mode == ROIAlign.Mode.AVERAGE:
            indexs.pop(-1)
        if isinstance(opdef, ROIPooling) and opdef.mode == ROIPooling.Mode.AVERAGE:
            indexs.pop(-1)
        return indexs

    def check_expr_results(self, expr_outputs, gt_outputs, indexs=None):
        """Assert that re-interpreted outputs match the traced outputs.

        Args:
            expr_outputs: result of re-interpreting the expr; a non-sequence
                is treated as a one-element tuple.
            gt_outputs: outputs produced during tracing; same wrapping rule.
            indexs: when given, only these positions are compared and each
                element is converted with ``numpy()``. When ``None`` the two
                sequences are handed to ``np.testing.assert_allclose`` as a whole.

        Raises:
            AssertionError: raised by ``np.testing.assert_allclose`` on mismatch.
        """
        if not isinstance(expr_outputs, Sequence):
            expr_outputs = (expr_outputs,)
        if not isinstance(gt_outputs, Sequence):
            gt_outputs = (gt_outputs,)
        if indexs is not None:
            for i in indexs:
                np.testing.assert_allclose(
                    expr_outputs[i].numpy(), gt_outputs[i].numpy()
                )
        else:
            np.testing.assert_allclose(expr_outputs, gt_outputs)

    def get_node2value(self, inputs, start_idx=0):
        """Look up the recorded values of ``inputs[start_idx:]``.

        Nodes without a recorded value are skipped and remembered in
        ``node_without_tensor_info`` together with the current scope and the
        current call stack.

        Returns:
            A pair ``(values, has_node_not_in_scope)``: the values that were
            found, in input order, and whether any lookup failed.
        """
        inp_values = []
        has_node_not_in_scope = False
        for node in inputs[start_idx:]:
            try:
                inp_values.append(self.current_node2values()[node])
            except (KeyError, IndexError):
                # KeyError: node never recorded in this scope;
                # IndexError: no scope is active at all.
                has_node_not_in_scope = True
                self.node_without_tensor_info[node] = [
                    self.tracer.current_scope(),
                    traceback.extract_stack(),
                ]
        return inp_values, has_node_not_in_scope

    def _compare_or_raise(self, expr, expr_res, gt_outputs, indexs=None):
        """Compare ``expr_res`` with ``gt_outputs``; wrap any failure in ValueError."""
        try:
            self.check_expr_results(expr_res, gt_outputs, indexs=indexs)
        except Exception as err:
            # Chain the original error so the numerical mismatch report is kept.
            raise ValueError(
                "Error occurred when checking expr: {}".format(expr)
            ) from err

    def check_expr_interpret(self, expr, gt_outputs):
        """Re-interpret ``expr`` on recorded inputs and compare with ``gt_outputs``.

        The check is skipped when any input has no recorded value.

        Raises:
            ValueError: if the comparison fails.
        """
        ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs)
        if not has_node_not_in_scope:
            expr_res = expr.interpret(*ori_in)
            self._compare_or_raise(expr, expr_res, gt_outputs)

    def check_apply(self, expr, gt_outputs, opdef):
        """Like :meth:`check_expr_interpret`, restricted to the comparable outputs of ``opdef``.

        Raises:
            ValueError: if the comparison fails.
        """
        ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs)
        if not has_node_not_in_scope:
            expr_res = expr.interpret(*ori_in)
            indexs = self.check_apply_special_cases(opdef, len(gt_outputs))
            self._compare_or_raise(expr, expr_res, gt_outputs, indexs=indexs)

    def check_builtin_module(self, module, expr, gt_outputs):
        """Re-interpret a module call expr with ``module`` as its first argument.

        ``expr.inputs[0]`` is not looked up; ``module`` is used in its place.
        The check is skipped when any remaining input has no recorded value.

        Raises:
            ValueError: if the comparison fails.
        """
        ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs, start_idx=1)
        if not has_node_not_in_scope:
            ori_in.insert(0, module)
            expr_res = expr.interpret(*ori_in)
            self._compare_or_raise(expr, expr_res, gt_outputs)
import functional as F from ..core.tensor.array_method import ArrayMethodMixin from ..module import Module from ..module.qat import QATModule +from .checker import TracedModuleChecker _active_module_tracer = None @@ -128,6 +129,7 @@ class module_tracer: def __init__(self, wrap_fn): self._active_scopes = [] + self.checker = TracedModuleChecker(self) self.patcher = Patcher(wrap_fn) @classmethod @@ -142,9 +144,11 @@ class module_tracer: def push_scope(self, scope): self._active_scopes.append(scope) + self.checker.push_scope() def pop_scope(self): self._active_scopes.pop() + self.checker.pop_scope() def current_scope(self): if self._active_scopes: diff --git a/imperative/python/megengine/traced_module/node.py b/imperative/python/megengine/traced_module/node.py index 4bfff4afc8eb3e45466884d88c42d1fd8a0cc669..364ab8f7a3d4111f46f2e3ba3ec1f09bf14525f3 100644 --- a/imperative/python/megengine/traced_module/node.py +++ b/imperative/python/megengine/traced_module/node.py @@ -18,6 +18,8 @@ from ..core._imperative_rt.core2 import Tensor as RawTensor from ..module import Module from ..quantization.utils import QParams from ..tensor import Tensor +from .module_tracer import active_module_tracer +from .tm_config import _get_expr_checker from .utils import _check_obj_attr logger = get_logger(__name__) @@ -343,6 +345,11 @@ class NodeMixin(abc.ABC): if isinstance(value, NodeMixin): value._record_wrapped_nodes(node) setattr(value, "_NodeMixin__node", node) + if _get_expr_checker(): + if isinstance(value, RawTensor): + active_module_tracer().checker.record_node2value(node, value) + if isinstance(value, NodeMixin): + active_module_tracer().checker.record_nodemixin(node, value) else: assert callable(node) n = node() @@ -352,6 +359,11 @@ class NodeMixin(abc.ABC): if isinstance(value, NodeMixin): value._record_wrapped_nodes(n) setattr(value, "_NodeMixin__node", n) + if _get_expr_checker(): + if isinstance(value, RawTensor): + active_module_tracer().checker.record_node2value(n, value) + if 
# Whether the result of every expr is re-checked while tracing.
_enable_expr_checker = False
# Whether the final outputs of the traced module are checked after tracing.
_enable_default_checker = True


def _get_expr_checker():
    """Return ``True`` when per-expr checking has been enabled."""
    return _enable_expr_checker


def _get_default_checker():
    """Return ``True`` when the final-output check is still enabled."""
    return _enable_default_checker


def enable_expr_checker():
    r"""Call this function to check the result of each expr during tracing.

    Enabling the per-expr check also switches off the default final-output check.
    """
    # Both flags must be declared global: without the declaration the second
    # assignment would only create a local variable and have no effect.
    global _enable_expr_checker, _enable_default_checker
    _enable_expr_checker = True
    _enable_default_checker = False


def disable_default_checker():
    r"""Call this function to disable checking the final output of the model after tracing."""
    global _enable_default_checker
    _enable_default_checker = False


_enable_graph_surgery_mode = False


def _graph_surgery_mode():
    """Return the current graph-surgery-mode flag."""
    return _enable_graph_surgery_mode


def _set_graph_surgery_mode(mode: bool):
    """Set the graph-surgery-mode flag and return its previous value."""
    global _enable_graph_surgery_mode
    pre_mode = _enable_graph_surgery_mode
    _enable_graph_surgery_mode = mode
    return pre_mode


@contextlib.contextmanager
def _exclude_from_trace():
    """Temporarily switch module tracing off for the body of the ``with`` block.

    Tracing is only touched when it was active on entry, and is switched back
    on when the block exits, including when the block raises.
    """
    is_tracing = is_tracing_module()
    if is_tracing:
        unset_module_tracing()
    try:
        yield
    finally:
        # Restore tracing even if the body raised, so a failed check does not
        # leave the tracer disabled.
        if is_tracing:
            set_module_tracing()
def _is_leaf(node):
    """Return ``True`` for a ``RawTensor``; any other value fails an assertion.

    The assertion message names the offending type and points the user to
    ``register_supported_type`` for self-defined return types.
    """
    is_raw_tensor = isinstance(node, RawTensor)
    assert (
        is_raw_tensor
    ), 'doesn\'t support {} in return values, MUST use Tensor or use "register_supported_type" method to register self-defined type'.format(
        type(node)
    )
    return is_raw_tensor
in tracing. NodeMixin.wrap_safe(inputs[1], node) @@ -1319,6 +1322,13 @@ def _wrapped_function(orig_func): else: outputs = None call_node.add_outputs(outputs) + + if _get_expr_checker(): + with _exclude_from_trace(): + active_module_tracer().checker.check_expr_interpret( + call_node, outputs + ) + set_module_tracing() return rst return orig_func(*args, **kwargs) @@ -1500,6 +1510,12 @@ class TracedModuleBuilder(NodeMixin): unset_module_tracing() rst = self._mod(*args, **kwargs) outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf) + if _get_expr_checker(): + with _exclude_from_trace(): + tmp = self.build() + active_module_tracer().checker.check_builtin_module( + tmp, callnode, outputs + ) set_module_tracing() if self._is_builtin: self._body = None @@ -1674,7 +1690,9 @@ class TracedModuleBuilder(NodeMixin): if not isinstance(mod_attr, (List, Dict, QATModule)): assert mod_attr is wrapped._mod else: - assert mod_attr is wrapped + assert ( + mod_attr is wrapped + ), "TracedModule do not support modify attributes, please check your code." 
if isinstance(wrapped, (NodeMixin, RawTensor)): NodeMixin.wrap( @@ -2469,11 +2487,23 @@ def trace_module( qualname="{}.[{}]".format(net_name, "arg_{}".format(_)), ), ) - builder(*args, **kwargs) + rst = builder(*copy.deepcopy(args), **copy.deepcopy(kwargs)) active_module_tracer().pop_scope() traced_mod = builder.build() traced_mod.argspec = forward_argspec traced_mod.graph._reset_ids() + + has_expr_not_check = False + if _get_expr_checker(): + has_expr_not_check = ( + active_module_tracer().checker.check_node_not_in_scope() + ) + if _get_default_checker() or has_expr_not_check: + with _exclude_from_trace(): + tm_res = traced_mod(*args, **kwargs) + tm_res, _ = tree_flatten(tm_res, is_leaf=_is_leaf) + rst, _ = tree_flatten(rst, is_leaf=_is_leaf) + active_module_tracer().checker.check_net_outputs(tm_res, rst) return traced_mod finally: set_symbolic_shape(use_sym_shape) diff --git a/imperative/python/megengine/traced_module/utils.py b/imperative/python/megengine/traced_module/utils.py index b722fc3067b167370d4d0692f39c0c458b714bdf..d93b658fc186bfab70e0d27788831e072c598d9f 100644 --- a/imperative/python/megengine/traced_module/utils.py +++ b/imperative/python/megengine/traced_module/utils.py @@ -5,16 +5,15 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -import collections import copy import inspect from collections.abc import MutableMapping, MutableSequence from inspect import FullArgSpec -from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union +from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union from .. 
import get_logger from ..module import Module -from ..tensor import Parameter, Tensor +from ..tensor import Tensor logger = get_logger(__name__) diff --git a/imperative/python/test/unit/traced_module/test_qat_module.py b/imperative/python/test/unit/traced_module/test_qat_module.py index 6ef8764b51f79253ede04a704f1aacaaf928d09b..57a9469363c85a46af2b92d67aa2752591bda9d9 100644 --- a/imperative/python/test/unit/traced_module/test_qat_module.py +++ b/imperative/python/test/unit/traced_module/test_qat_module.py @@ -109,6 +109,7 @@ def build_observered_net(net: M.Module, observer_cls): ) Q.enable_observer(qat_net) inp = Tensor(np.random.random(size=(5, 3, 32, 32))) + qat_net.eval() qat_net(inp) Q.disable_observer(qat_net) return qat_net @@ -116,6 +117,7 @@ def build_observered_net(net: M.Module, observer_cls): def build_fakequanted_net(net: QATModule, fakequant_cls): qat_net = Q.reset_qconfig(net, get_lsq_config(fakequant_cls)) + qat_net.eval() return qat_net @@ -162,6 +164,7 @@ def test_load_param(): def _check_module(build_func: Callable): net = build_func() + net.eval() buffer = io.BytesIO() mge.save(net.state_dict(), buffer) buffer.seek(0) @@ -185,6 +188,7 @@ def test_load_param(): def test_qualname(): def _check_qualname(net): inp = Tensor(np.random.random(size=(5, 3, 32, 32))) + net.eval() traced_net = trace_module(net, inp) base_qualname = traced_net.graph.qualname for node in traced_net.graph.nodes():