提交 355782ae 编写于 作者: M Megvii Engine Team

fix(traced_module): clear node after trace module

GitOrigin-RevId: f7f602403481fdeb6a77435bc98c5d9e7a5fa58e
上级 fba54488
......@@ -763,6 +763,7 @@ class Constant(Expr):
current_graph = active_module_tracer().current_scope()
current_graph._namespace.auto_naming_for_outputs(expr)
current_graph._insert(expr)
active_module_tracer().current_constant_cache().append(expr.value)
return expr.outputs[0]
def interpret(self, *inputs):
......
......@@ -131,6 +131,7 @@ class module_tracer:
self._active_scopes = []
self.checker = TracedModuleChecker(self)
self.patcher = Patcher(wrap_fn)
self._activate_constant_cache = []
@classmethod
def register_as_builtin(cls, mod):
......@@ -145,16 +146,28 @@ class module_tracer:
def push_scope(self, scope):
self._active_scopes.append(scope)
self.checker.push_scope()
self._activate_constant_cache.append([])
def pop_scope(self):
self._active_scopes.pop()
self.checker.pop_scope()
cache = self._activate_constant_cache.pop()
for obj in cache:
if hasattr(obj, "_NodeMixin__node"):
delattr(obj, "_NodeMixin__node")
def current_scope(self):
if self._active_scopes:
return self._active_scopes[-1]
return None
def current_constant_cache(self):
if self._activate_constant_cache:
return self._activate_constant_cache[-1]
return None
def top_scope(self):
if self._active_scopes:
return self._active_scopes[0]
......
......@@ -379,6 +379,11 @@ class NodeMixin(abc.ABC):
if isinstance(value, NodeMixin):
value._record_wrapped_nodes(node)
@classmethod
def clear_node(cls, value):
if hasattr(value, "_NodeMixin__node"):
delattr(value, "_NodeMixin__node")
@classmethod
def get(cls, value, *default):
return getattr(value, "_NodeMixin__node", *default)
......
......@@ -1980,7 +1980,10 @@ class TracedModule(Module):
assert (
treedef in self.argdef_graph_map
), "support input args kwargs format: \n{}, but get: \n{}".format(
"\n ".join("forward({})".format(i._args_kwargs_repr()) for i in self.argdef_graph_map.keys()),
"\n ".join(
"forward({})".format(i._args_kwargs_repr())
for i in self.argdef_graph_map.keys()
),
treedef._args_kwargs_repr(),
)
inputs = filter(
......@@ -2514,3 +2517,7 @@ def trace_module(
set_symbolic_shape(use_sym_shape)
set_active_module_tracer(None)
unset_module_tracing()
for t in mod.tensors(recursive=True):
NodeMixin.clear_node(t)
for t in inputs:
NodeMixin.clear_node(t)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册