提交 04b14241 编写于 作者: M Megvii Engine Team

fix(traced_module): fix TracedModule InternalGraph deepcopy exceed max recursion limit

GitOrigin-RevId: 2b52ad913d40e8bb54c9dcadc559a35b3b39099a
上级 355782ae
......@@ -148,7 +148,6 @@ class module_tracer:
self.checker.push_scope()
self._activate_constant_cache.append([])
def pop_scope(self):
self._active_scopes.pop()
self.checker.pop_scope()
......@@ -157,7 +156,6 @@ class module_tracer:
if hasattr(obj, "_NodeMixin__node"):
delattr(obj, "_NodeMixin__node")
def current_scope(self):
if self._active_scopes:
return self._active_scopes[-1]
......
......@@ -58,6 +58,7 @@ from ..quantization.observer import (
SyncMinMaxObserver,
)
from ..tensor import Tensor
from ..utils.max_recursion_limit import max_recursion_limit
from ..version import __version__
from .expr import (
Apply,
......@@ -1247,17 +1248,18 @@ class InternalGraph:
return result
def __deepcopy__(self, memo):
if id(self) in memo:
return memo[id(self)]
cls = self.__class__
result = cls.__new__(cls)
state = {}
memo[id(self)] = result
for k, v in self.__dict__.items():
if not isinstance(v, weakref.ReferenceType):
state[k] = copy.deepcopy(v, memo)
result.__dict__.update(state)
return result
with max_recursion_limit():
if id(self) in memo:
return memo[id(self)]
cls = self.__class__
result = cls.__new__(cls)
state = {}
memo[id(self)] = result
for k, v in self.__dict__.items():
if not isinstance(v, weakref.ReferenceType):
state[k] = copy.deepcopy(v, memo)
result.__dict__.update(state)
return result
def _get_meth_name(obj, func):
......@@ -2359,16 +2361,17 @@ class TracedModule(Module):
return result
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
state = {}
memo[id(self)] = result
for k, v in self.__dict__.items():
if not isinstance(v, weakref.ReferenceType):
state[k] = copy.deepcopy(v, memo)
result.__dict__.update(state)
result._update_ref()
return result
with max_recursion_limit():
cls = self.__class__
result = cls.__new__(cls)
state = {}
memo[id(self)] = result
for k, v in self.__dict__.items():
if not isinstance(v, weakref.ReferenceType):
state[k] = copy.deepcopy(v, memo)
result.__dict__.update(state)
result._update_ref()
return result
def cpp_apply_module_trace(opdef, *args):
......
......@@ -6,7 +6,7 @@ import megengine.functional as F
import megengine.module as M
from megengine import Tensor
from megengine.module.module import Module
from megengine.traced_module import TracedModule, trace_module
from megengine.traced_module import TracedModule, enable_expr_checker, trace_module
from megengine.traced_module.expr import CallFunction
......@@ -58,7 +58,7 @@ class MyModule4(M.Module):
def test_trace_module():
enable_expr_checker()
x = Tensor(1)
m1 = MyModule1()
tm1 = trace_module(m1, x)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册