diff --git a/imperative/python/megengine/traced_module/traced_module.py b/imperative/python/megengine/traced_module/traced_module.py index 58f7bf52ff28c63ff94227e087ab8dd8479d5727..10798776000399c99b20c829e633cc929a1ee063 100644 --- a/imperative/python/megengine/traced_module/traced_module.py +++ b/imperative/python/megengine/traced_module/traced_module.py @@ -145,9 +145,8 @@ def _node_to_tensor(*args, **kwargs): value = n.value if value is None: flag = _set_graph_surgery_mode(False) - unset_module_tracing() - value = F.zeros(shape=n._shape, dtype=n._dtype) - set_module_tracing() + with _exclude_from_trace(): + value = F.zeros(shape=n._shape, dtype=n._dtype) _set_graph_surgery_mode(flag) orig_n = NodeMixin.get(value, None) if orig_n is None or "setitem" not in orig_n._name: @@ -1274,8 +1273,10 @@ def _wrapped_function(orig_func): @functools.wraps(orig_func) def wrapped_fn(*args, **kwargs): method_func = kwargs.pop("method_func", wrapped_fn) - if is_tracing_module(): - unset_module_tracing() + if not is_tracing_module(): + return orig_func(*args, **kwargs) + + with _exclude_from_trace(): inputs, tree_def = tree_flatten((args, kwargs)) for i in inputs: if not NodeMixin.get(i, None): @@ -1290,7 +1291,6 @@ def _wrapped_function(orig_func): if meth_name == "__new__": if all([not isinstance(i, RawTensor) for i in inputs]): # only trace Tensor.__new__() when there are tensors in args - set_module_tracing() return orig_func(*args, **kwargs) if isinstance(args[1], RawTensor): node = NodeMixin.get(inputs[1]) @@ -1327,9 +1327,7 @@ def _wrapped_function(orig_func): call_node, outputs ) - set_module_tracing() return rst - return orig_func(*args, **kwargs) return wrapped_fn @@ -1339,8 +1337,8 @@ class TracedModuleBuilder(NodeMixin): _mod = None # type: Module _body = None # type: InternalGraph _is_builtin = None # type: bool - _argdef_graph_map = None # type: Dict[Treedef, "InternalGraph"] - _argdef_outdef_map = None # type: Dict[Treedef, Treedef] + _argdef_graph_map = None # type: Dict[TreeDef, "InternalGraph"] + _argdef_outdef_map = None # type: Dict[TreeDef, TreeDef] nodes = None __builder_attributes__ = [ @@ -1371,9 +1369,8 @@ class TracedModuleBuilder(NodeMixin): else module_tracer.is_builtin(mod) ) if isinstance(self._mod, QATModule): - unset_module_tracing() - self._check_qat_module(self._mod) - set_module_tracing() + with _exclude_from_trace(): + self._check_qat_module(self._mod) self._argdef_graph_map = {} self._argdef_outdef_map = {} @@ -1458,18 +1455,17 @@ class TracedModuleBuilder(NodeMixin): setattr(traced_module, k, v) if isinstance(self._mod, QATModule): - unset_module_tracing() - traced_module.with_act = self._mod.with_act - traced_module.with_weight = self._mod.with_weight - if not hasattr(traced_module, "act_fake_quant"): - traced_module.act_fake_quant = None - if not hasattr(traced_module, "act_observer"): - traced_module.act_observer = None - if not hasattr(traced_module, "weight_fake_quant"): - traced_module.weight_fake_quant = None - if not hasattr(traced_module, "weight_observer"): - traced_module.weight_observer = None - set_module_tracing() + with _exclude_from_trace(): + traced_module.with_act = self._mod.with_act + traced_module.with_weight = self._mod.with_weight + if not hasattr(traced_module, "act_fake_quant"): + traced_module.act_fake_quant = None + if not hasattr(traced_module, "act_observer"): + traced_module.act_observer = None + if not hasattr(traced_module, "weight_fake_quant"): + traced_module.weight_fake_quant = None + if not hasattr(traced_module, "weight_observer"): + traced_module.weight_observer = None if self._is_top: traced_module._update_ref() @@ -1505,16 +1501,14 @@ class TracedModuleBuilder(NodeMixin): callnode.arg_def = tree_def if self._is_builtin or tree_def in self._argdef_graph_map: - unset_module_tracing() - rst = self._mod(*args, **kwargs) - outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf) - if _get_expr_checker(): - with _exclude_from_trace(): + with _exclude_from_trace(): + rst = self._mod(*args, **kwargs) + outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf) + if _get_expr_checker(): tmp = self.build() active_module_tracer().checker.check_builtin_module( tmp, callnode, outputs ) - set_module_tracing() if self._is_builtin: self._body = None elif tree_def in self._argdef_graph_map: @@ -1640,16 +1634,17 @@ class TracedModuleBuilder(NodeMixin): if isinstance(attr, (List, Dict)): flag = _set_graph_surgery_mode(False) - unset_module_tracing() - has_module, m_container = replace_container_with_module_container(attr) - if m_container: - attr = m_container - if has_module and not m_container: - raise ValueError( - "Can not trace the module that uses the same container to store" - " Module and Non-Module objects." + with _exclude_from_trace(): + has_module, m_container = replace_container_with_module_container( + attr ) - set_module_tracing() + if m_container: + attr = m_container + if has_module and not m_container: + raise ValueError( + "Can not trace the module that uses the same container to store" + " Module and Non-Module objects." + ) _set_graph_surgery_mode(flag) if isinstance(attr, Module): diff --git a/imperative/python/megengine/traced_module/utils.py b/imperative/python/megengine/traced_module/utils.py index 8fe9dd873b83d5a00f25c04b2fbd9a0de2f1c76c..9a262c528de8505cd79806eb14caf0d13f65daa2 100644 --- a/imperative/python/megengine/traced_module/utils.py +++ b/imperative/python/megengine/traced_module/utils.py @@ -5,8 +5,8 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import collections import copy -import contextlib import inspect from collections.abc import MutableMapping, MutableSequence from inspect import FullArgSpec @@ -71,7 +71,9 @@ def _convert_kwargs_to_args( arg_specs_args = arg_specs.args arg_specs_defaults = arg_specs.defaults if arg_specs.defaults else [] arg_specs_kwonlyargs = arg_specs.kwonlyargs - arg_specs_kwonlydefaults = arg_specs.kwonlydefaults if arg_specs.kwonlydefaults else dict() + arg_specs_kwonlydefaults = ( + arg_specs.kwonlydefaults if arg_specs.kwonlydefaults else dict() + ) if is_bounded: arg_specs_args = arg_specs.args[1:] new_args = [] @@ -104,17 +106,17 @@ def _convert_kwargs_to_args( new_kwargs[kwarg_name] = kwargs[kwarg_name] else: if kwarg_name not in arg_specs_kwonlydefaults: - raise TypeError("{} missing required keyword-only argument: {}".format( - func_name, kwarg_name - )) + raise TypeError( + "{} missing required keyword-only argument: {}".format( + func_name, kwarg_name + ) + ) new_kwargs[kwarg_name] = arg_specs_kwonlydefaults[kwarg_name] for k, v in kwargs.items(): if k not in arg_specs_args and k not in arg_specs_kwonlyargs: if arg_specs.varkw is None: raise TypeError( - "{} got an unexpected keyword argument {}".format( - func_name, k - ) + "{} got an unexpected keyword argument {}".format(func_name, k) ) new_kwargs[k] = v return tuple(new_args), new_kwargs