diff --git a/imperative/python/megengine/traced_module/module_tracer.py b/imperative/python/megengine/traced_module/module_tracer.py index bfe4ea41cb221fc007662825fe73eeda0e1e9fe5..70a020f4863b5ccb70b7407666e753ac1378099e 100644 --- a/imperative/python/megengine/traced_module/module_tracer.py +++ b/imperative/python/megengine/traced_module/module_tracer.py @@ -148,7 +148,6 @@ class module_tracer: self.checker.push_scope() self._activate_constant_cache.append([]) - def pop_scope(self): self._active_scopes.pop() self.checker.pop_scope() @@ -157,7 +156,6 @@ class module_tracer: if hasattr(obj, "_NodeMixin__node"): delattr(obj, "_NodeMixin__node") - def current_scope(self): if self._active_scopes: return self._active_scopes[-1] diff --git a/imperative/python/megengine/traced_module/traced_module.py b/imperative/python/megengine/traced_module/traced_module.py index 4950a4535020ebc58ce621279637244b0c39c348..670ab7e90306ae2fbe5a613e606a91abe6969398 100644 --- a/imperative/python/megengine/traced_module/traced_module.py +++ b/imperative/python/megengine/traced_module/traced_module.py @@ -58,6 +58,7 @@ from ..quantization.observer import ( SyncMinMaxObserver, ) from ..tensor import Tensor +from ..utils.max_recursion_limit import max_recursion_limit from ..version import __version__ from .expr import ( Apply, @@ -1247,17 +1248,18 @@ class InternalGraph: return result def __deepcopy__(self, memo): - if id(self) in memo: - return memo[id(self)] - cls = self.__class__ - result = cls.__new__(cls) - state = {} - memo[id(self)] = result - for k, v in self.__dict__.items(): - if not isinstance(v, weakref.ReferenceType): - state[k] = copy.deepcopy(v, memo) - result.__dict__.update(state) - return result + with max_recursion_limit(): + if id(self) in memo: + return memo[id(self)] + cls = self.__class__ + result = cls.__new__(cls) + state = {} + memo[id(self)] = result + for k, v in self.__dict__.items(): + if not isinstance(v, weakref.ReferenceType): + state[k] = copy.deepcopy(v, memo) + result.__dict__.update(state) + return result def _get_meth_name(obj, func): @@ -2359,16 +2361,17 @@ class TracedModule(Module): return result def __deepcopy__(self, memo): - cls = self.__class__ - result = cls.__new__(cls) - state = {} - memo[id(self)] = result - for k, v in self.__dict__.items(): - if not isinstance(v, weakref.ReferenceType): - state[k] = copy.deepcopy(v, memo) - result.__dict__.update(state) - result._update_ref() - return result + with max_recursion_limit(): + cls = self.__class__ + result = cls.__new__(cls) + state = {} + memo[id(self)] = result + for k, v in self.__dict__.items(): + if not isinstance(v, weakref.ReferenceType): + state[k] = copy.deepcopy(v, memo) + result.__dict__.update(state) + result._update_ref() + return result def cpp_apply_module_trace(opdef, *args): diff --git a/imperative/python/test/unit/traced_module/test_trace_module.py b/imperative/python/test/unit/traced_module/test_trace_module.py index e4441c49f372cf4e8c3e4073874322f3d7349026..d3baf1530347ac01cabbf861afdf1b9204379507 100644 --- a/imperative/python/test/unit/traced_module/test_trace_module.py +++ b/imperative/python/test/unit/traced_module/test_trace_module.py @@ -6,7 +6,7 @@ import megengine.functional as F import megengine.module as M from megengine import Tensor from megengine.module.module import Module -from megengine.traced_module import TracedModule, trace_module +from megengine.traced_module import TracedModule, enable_expr_checker, trace_module from megengine.traced_module.expr import CallFunction @@ -58,7 +58,7 @@ class MyModule4(M.Module): def test_trace_module(): - + enable_expr_checker() x = Tensor(1) m1 = MyModule1() tm1 = trace_module(m1, x)