From 355782aecb954fe015c096d592efac7edc2a7bdb Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 11 Jan 2022 13:47:45 +0800 Subject: [PATCH] fix(traced_module): clear node after trace module GitOrigin-RevId: f7f602403481fdeb6a77435bc98c5d9e7a5fa58e --- imperative/python/megengine/traced_module/expr.py | 1 + .../python/megengine/traced_module/module_tracer.py | 13 +++++++++++++ imperative/python/megengine/traced_module/node.py | 5 +++++ .../python/megengine/traced_module/traced_module.py | 9 ++++++++- 4 files changed, 27 insertions(+), 1 deletion(-) diff --git a/imperative/python/megengine/traced_module/expr.py b/imperative/python/megengine/traced_module/expr.py index 759ac1c8e..c22249fc4 100644 --- a/imperative/python/megengine/traced_module/expr.py +++ b/imperative/python/megengine/traced_module/expr.py @@ -763,6 +763,7 @@ class Constant(Expr): current_graph = active_module_tracer().current_scope() current_graph._namespace.auto_naming_for_outputs(expr) current_graph._insert(expr) + active_module_tracer().current_constant_cache().append(expr.value) return expr.outputs[0] def interpret(self, *inputs): diff --git a/imperative/python/megengine/traced_module/module_tracer.py b/imperative/python/megengine/traced_module/module_tracer.py index 7bd8ab419..bfe4ea41c 100644 --- a/imperative/python/megengine/traced_module/module_tracer.py +++ b/imperative/python/megengine/traced_module/module_tracer.py @@ -131,6 +131,7 @@ class module_tracer: self._active_scopes = [] self.checker = TracedModuleChecker(self) self.patcher = Patcher(wrap_fn) + self._activate_constant_cache = [] @classmethod def register_as_builtin(cls, mod): @@ -145,16 +146,28 @@ class module_tracer: def push_scope(self, scope): self._active_scopes.append(scope) self.checker.push_scope() + self._activate_constant_cache.append([]) + def pop_scope(self): self._active_scopes.pop() self.checker.pop_scope() + cache = self._activate_constant_cache.pop() + for obj in cache: + if hasattr(obj, "_NodeMixin__node"): + delattr(obj, "_NodeMixin__node") + def current_scope(self): if self._active_scopes: return self._active_scopes[-1] return None + def current_constant_cache(self): + if self._activate_constant_cache: + return self._activate_constant_cache[-1] + return None + def top_scope(self): if self._active_scopes: return self._active_scopes[0] diff --git a/imperative/python/megengine/traced_module/node.py b/imperative/python/megengine/traced_module/node.py index 364ab8f7a..079ee46ec 100644 --- a/imperative/python/megengine/traced_module/node.py +++ b/imperative/python/megengine/traced_module/node.py @@ -379,6 +379,11 @@ class NodeMixin(abc.ABC): if isinstance(value, NodeMixin): value._record_wrapped_nodes(node) + @classmethod + def clear_node(cls, value): + if hasattr(value, "_NodeMixin__node"): + delattr(value, "_NodeMixin__node") + @classmethod def get(cls, value, *default): return getattr(value, "_NodeMixin__node", *default) diff --git a/imperative/python/megengine/traced_module/traced_module.py b/imperative/python/megengine/traced_module/traced_module.py index c4c25094c..4950a4535 100644 --- a/imperative/python/megengine/traced_module/traced_module.py +++ b/imperative/python/megengine/traced_module/traced_module.py @@ -1980,7 +1980,10 @@ class TracedModule(Module): assert ( treedef in self.argdef_graph_map ), "support input args kwargs format: \n{}, but get: \n{}".format( - "\n ".join("forward({})".format(i._args_kwargs_repr()) for i in self.argdef_graph_map.keys()), + "\n ".join( + "forward({})".format(i._args_kwargs_repr()) + for i in self.argdef_graph_map.keys() + ), treedef._args_kwargs_repr(), ) inputs = filter( @@ -2514,3 +2517,7 @@ def trace_module( set_symbolic_shape(use_sym_shape) set_active_module_tracer(None) unset_module_tracing() + for t in mod.tensors(recursive=True): + NodeMixin.clear_node(t) + for t in inputs: + NodeMixin.clear_node(t) -- GitLab