diff --git a/imperative/python/megengine/traced_module/expr.py b/imperative/python/megengine/traced_module/expr.py index 759ac1c8e8eca110b5bee7bd9d73961e783260ce..c22249fc41d8a2415e1bf202994ea3403428dd87 100644 --- a/imperative/python/megengine/traced_module/expr.py +++ b/imperative/python/megengine/traced_module/expr.py @@ -763,6 +763,7 @@ class Constant(Expr): current_graph = active_module_tracer().current_scope() current_graph._namespace.auto_naming_for_outputs(expr) current_graph._insert(expr) + active_module_tracer().current_constant_cache().append(expr.value) return expr.outputs[0] def interpret(self, *inputs): diff --git a/imperative/python/megengine/traced_module/module_tracer.py b/imperative/python/megengine/traced_module/module_tracer.py index 7bd8ab419fd0f21e8b4588fddab0fe404622b594..bfe4ea41cb221fc007662825fe73eeda0e1e9fe5 100644 --- a/imperative/python/megengine/traced_module/module_tracer.py +++ b/imperative/python/megengine/traced_module/module_tracer.py @@ -131,6 +131,7 @@ class module_tracer: self._active_scopes = [] self.checker = TracedModuleChecker(self) self.patcher = Patcher(wrap_fn) + self._activate_constant_cache = [] @classmethod def register_as_builtin(cls, mod): @@ -145,16 +146,28 @@ class module_tracer: def push_scope(self, scope): self._active_scopes.append(scope) self.checker.push_scope() + self._activate_constant_cache.append([]) + def pop_scope(self): self._active_scopes.pop() self.checker.pop_scope() + cache = self._activate_constant_cache.pop() + for obj in cache: + if hasattr(obj, "_NodeMixin__node"): + delattr(obj, "_NodeMixin__node") + def current_scope(self): if self._active_scopes: return self._active_scopes[-1] return None + def current_constant_cache(self): + if self._activate_constant_cache: + return self._activate_constant_cache[-1] + return None + def top_scope(self): if self._active_scopes: return self._active_scopes[0] diff --git a/imperative/python/megengine/traced_module/node.py b/imperative/python/megengine/traced_module/node.py index 364ab8f7a3d4111f46f2e3ba3ec1f09bf14525f3..079ee46ecab6c03dd0369785f95119305dca2a90 100644 --- a/imperative/python/megengine/traced_module/node.py +++ b/imperative/python/megengine/traced_module/node.py @@ -379,6 +379,11 @@ class NodeMixin(abc.ABC): if isinstance(value, NodeMixin): value._record_wrapped_nodes(node) + @classmethod + def clear_node(cls, value): + if hasattr(value, "_NodeMixin__node"): + delattr(value, "_NodeMixin__node") + @classmethod def get(cls, value, *default): return getattr(value, "_NodeMixin__node", *default) diff --git a/imperative/python/megengine/traced_module/traced_module.py b/imperative/python/megengine/traced_module/traced_module.py index c4c25094cadf1c8862e69c4d172c61b1dec81056..4950a4535020ebc58ce621279637244b0c39c348 100644 --- a/imperative/python/megengine/traced_module/traced_module.py +++ b/imperative/python/megengine/traced_module/traced_module.py @@ -1980,7 +1980,10 @@ class TracedModule(Module): assert ( treedef in self.argdef_graph_map ), "support input args kwargs format: \n{}, but get: \n{}".format( - "\n ".join("forward({})".format(i._args_kwargs_repr()) for i in self.argdef_graph_map.keys()), + "\n ".join( + "forward({})".format(i._args_kwargs_repr()) + for i in self.argdef_graph_map.keys() + ), treedef._args_kwargs_repr(), ) inputs = filter( @@ -2514,3 +2517,7 @@ def trace_module( set_symbolic_shape(use_sym_shape) set_active_module_tracer(None) unset_module_tracing() + for t in mod.tensors(recursive=True): + NodeMixin.clear_node(t) + for t in inputs: + NodeMixin.clear_node(t)