From b3b14fdfe18ac5734e60818d62198f6dfd86483e Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 13 May 2020 19:38:01 +0800 Subject: [PATCH] fix(mge/jit): fix add_update semantic GitOrigin-RevId: f541ac7c6d2dcef2f31c9d623ec92ef3a567f4db --- python_module/megengine/core/tensor.py | 37 ++++++++++++++++++--- python_module/megengine/functional/graph.py | 5 ++- python_module/megengine/jit/__init__.py | 2 ++ python_module/test/unit/jit/test_jit.py | 14 ++++++++ 4 files changed, 53 insertions(+), 5 deletions(-) diff --git a/python_module/megengine/core/tensor.py b/python_module/megengine/core/tensor.py index 37d0672d..b4aa4dda 100644 --- a/python_module/megengine/core/tensor.py +++ b/python_module/megengine/core/tensor.py @@ -9,6 +9,7 @@ import collections import functools import itertools +import weakref from typing import Union import numpy as np @@ -100,6 +101,14 @@ class MGBIndexWrapper: )(wrap_idx(idx)) +class Guard: + def __init__(self, deleter): + self.deleter = deleter + + def __del__(self): + self.deleter() + + class Tensor: r"""The main data container in MegEngine. Use :func:`~.tensor` to create a Tensor with existed data. @@ -111,6 +120,7 @@ class Tensor: self._reset(val, requires_grad=requires_grad) def _reset(self, val=None, *, requires_grad=None): + self.__sym_override = None if val is None: self.__val = None self.__sym = None @@ -154,17 +164,20 @@ class Tensor: return self.numpy().item() def _attach(self, comp_graph, *, volatile=True): + sym = self.__sym_override or self.__sym + if sym: + if sym.owner_graph != comp_graph: + raise RuntimeError("internal error") + return sym if self.__val: return self.__val.symvar(comp_graph, volatile=volatile) - if self.__sym: - if self.__sym.owner_graph != comp_graph: - raise RuntimeError("internal error") - return self.__sym else: raise ValueError("uninitialized") @property def _symvar(self): + if self.__sym_override: + return self.__sym_override if self.__sym: assert not self.__val return self.__sym @@ -174,10 +187,26 @@ class Tensor: return self._attach(get_default_graph()) def __mgb_symvar__(self, comp_graph=None, **_): + if self.__sym_override: + return self.__sym_override if self.__val and comp_graph: return self._attach(comp_graph) return self._symvar # read by mgb.opr + def _override_symvar_during_trace(self, trace, symvar): + assert self.__val and not self.__sym + assert trace is type(trace)._active_instance + deleters = trace._user_cache.setdefault(Tensor, set()) + self_ref = weakref.ref(self) + + def restore(): + self = self_ref() + if self is not None: + self.__sym_override = None + + deleters.add(Guard(restore)) + self.__sym_override = symvar + @property def dtype(self): r"""Return the data type of the tensor. diff --git a/python_module/megengine/functional/graph.py b/python_module/megengine/functional/graph.py index a5e6f6e1..5dbdadb6 100644 --- a/python_module/megengine/functional/graph.py +++ b/python_module/megengine/functional/graph.py @@ -13,7 +13,7 @@ import megengine._internal as mgb from ..core.graph import get_default_graph from ..core.tensor import Tensor, wrap_io_tensor -from ..jit import barrier, mark_impure +from ..jit import barrier, mark_impure, trace @wrap_io_tensor @@ -112,6 +112,9 @@ def add_update( ) mark_impure(u) + if trace._active_instance: + dest._override_symvar_during_trace(trace._active_instance, u) + return Tensor(u) diff --git a/python_module/megengine/jit/__init__.py b/python_module/megengine/jit/__init__.py index e8f06326..7da88c91 100644 --- a/python_module/megengine/jit/__init__.py +++ b/python_module/megengine/jit/__init__.py @@ -367,10 +367,12 @@ class trace: raise RuntimeError("nested trace is unsupported") self._status = self._STARTED type(self)._active_instance = self + self._user_cache = {} try: yield finally: self._status = self._FINISHED + self._user_cache = None type(self)._active_instance = None def _run_wrapped(self): diff --git a/python_module/test/unit/jit/test_jit.py b/python_module/test/unit/jit/test_jit.py index b90655aa..d2b6eecb 100644 --- a/python_module/test/unit/jit/test_jit.py +++ b/python_module/test/unit/jit/test_jit.py @@ -16,6 +16,7 @@ import pytest import megengine as mge import megengine._internal as mgb import megengine.module as M +from megengine import functional as F from megengine import jit, tensor from megengine.core.tensor import Tensor from megengine.jit import SublinearMemoryConfig @@ -57,6 +58,19 @@ def test_symbolic(): f.trace(0) +def test_add_update_semantic(): + for symbolic in [False, True]: + x = tensor(0) + + @jit.trace(symbolic=symbolic) + def f(): + F.add_update(x, 1) + return x + 1 + + np.testing.assert_equal(f().numpy(), [2]) + np.testing.assert_equal(f().numpy(), [3]) + + def test_dump(): @jit.trace(symbolic=True) def f(x, y): -- GitLab