diff --git a/python_module/megengine/core/tensor.py b/python_module/megengine/core/tensor.py
index 37d0672d9c55acf243afbe486eef23430d1a370c..b4aa4dda981c0499fbd2376f8351e0f368a61898 100644
--- a/python_module/megengine/core/tensor.py
+++ b/python_module/megengine/core/tensor.py
@@ -9,6 +9,7 @@
 import collections
 import functools
 import itertools
+import weakref
 from typing import Union
 
 import numpy as np
@@ -100,6 +101,14 @@
         )(wrap_idx(idx))
 
 
+class Guard:
+    def __init__(self, deleter):
+        self.deleter = deleter
+
+    def __del__(self):
+        self.deleter()
+
+
 class Tensor:
     r"""The main data container in MegEngine.
     Use :func:`~.tensor` to create a Tensor with existed data.
@@ -111,6 +120,7 @@
         self._reset(val, requires_grad=requires_grad)
 
     def _reset(self, val=None, *, requires_grad=None):
+        self.__sym_override = None
         if val is None:
             self.__val = None
             self.__sym = None
@@ -154,17 +164,20 @@
         return self.numpy().item()
 
     def _attach(self, comp_graph, *, volatile=True):
+        sym = self.__sym_override or self.__sym
+        if sym:
+            if sym.owner_graph != comp_graph:
+                raise RuntimeError("internal error")
+            return sym
         if self.__val:
             return self.__val.symvar(comp_graph, volatile=volatile)
-        if self.__sym:
-            if self.__sym.owner_graph != comp_graph:
-                raise RuntimeError("internal error")
-            return self.__sym
         else:
             raise ValueError("uninitialized")
 
     @property
     def _symvar(self):
+        if self.__sym_override:
+            return self.__sym_override
         if self.__sym:
             assert not self.__val
             return self.__sym
@@ -174,10 +187,26 @@
         return self._attach(get_default_graph())
 
     def __mgb_symvar__(self, comp_graph=None, **_):
+        if self.__sym_override:
+            return self.__sym_override
         if self.__val and comp_graph:
             return self._attach(comp_graph)
         return self._symvar  # read by mgb.opr
 
+    def _override_symvar_during_trace(self, trace, symvar):
+        assert self.__val and not self.__sym
+        assert trace is type(trace)._active_instance
+        deleters = trace._user_cache.setdefault(Tensor, set())
+        self_ref = weakref.ref(self)
+
+        def restore():
+            self = self_ref()
+            if self is not None:
+                self.__sym_override = None
+
+        deleters.add(Guard(restore))
+        self.__sym_override = symvar
+
     @property
     def dtype(self):
         r"""Return the data type of the tensor.
diff --git a/python_module/megengine/functional/graph.py b/python_module/megengine/functional/graph.py
index a5e6f6e17eccb14ad6e72da4f1fdcd456b132bc6..5dbdadb60386be8411808f1ae3c79f942c0affd5 100644
--- a/python_module/megengine/functional/graph.py
+++ b/python_module/megengine/functional/graph.py
@@ -13,7 +13,7 @@
 import megengine._internal as mgb
 from ..core.graph import get_default_graph
 from ..core.tensor import Tensor, wrap_io_tensor
-from ..jit import barrier, mark_impure
+from ..jit import barrier, mark_impure, trace
 
 
 @wrap_io_tensor
@@ -112,6 +112,9 @@
     )
     mark_impure(u)
 
+    if trace._active_instance:
+        dest._override_symvar_during_trace(trace._active_instance, u)
+
     return Tensor(u)
 
 
diff --git a/python_module/megengine/jit/__init__.py b/python_module/megengine/jit/__init__.py
index e8f063267f4b9adbb25b0c26809b7364ab6c0452..7da88c915e20ee0fddd0e0b9c66a2ca9e97533a9 100644
--- a/python_module/megengine/jit/__init__.py
+++ b/python_module/megengine/jit/__init__.py
@@ -367,10 +367,12 @@
             raise RuntimeError("nested trace is unsupported")
         self._status = self._STARTED
         type(self)._active_instance = self
+        self._user_cache = {}
         try:
             yield
         finally:
             self._status = self._FINISHED
+            self._user_cache = None
             type(self)._active_instance = None
 
     def _run_wrapped(self):
diff --git a/python_module/test/unit/jit/test_jit.py b/python_module/test/unit/jit/test_jit.py
index b90655aaa94b8f3c4d35e4aa373cee54739c8f57..d2b6eecbf33160a9d39b90eb7e97fdb3f9799b60 100644
--- a/python_module/test/unit/jit/test_jit.py
+++ b/python_module/test/unit/jit/test_jit.py
@@ -16,6 +16,7 @@
 import megengine as mge
 import megengine._internal as mgb
 import megengine.module as M
+from megengine import functional as F
 from megengine import jit, tensor
 from megengine.core.tensor import Tensor
 from megengine.jit import SublinearMemoryConfig
@@ -57,6 +58,19 @@
     f.trace(0)
 
 
+def test_add_update_semantic():
+    for symbolic in [False, True]:
+        x = tensor(0)
+
+        @jit.trace(symbolic=symbolic)
+        def f():
+            F.add_update(x, 1)
+            return x + 1
+
+        np.testing.assert_equal(f().numpy(), [2])
+        np.testing.assert_equal(f().numpy(), [3])
+
+
 def test_dump():
     @jit.trace(symbolic=True)
     def f(x, y):