提交 b3b14fdf 编写于 作者: M Megvii Engine Team

fix(mge/jit): fix add_update semantic

GitOrigin-RevId: f541ac7c6d2dcef2f31c9d623ec92ef3a567f4db
上级 2bbce2f9
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
import collections import collections
import functools import functools
import itertools import itertools
import weakref
from typing import Union from typing import Union
import numpy as np import numpy as np
...@@ -100,6 +101,14 @@ class MGBIndexWrapper: ...@@ -100,6 +101,14 @@ class MGBIndexWrapper:
)(wrap_idx(idx)) )(wrap_idx(idx))
class Guard:
def __init__(self, deleter):
self.deleter = deleter
def __del__(self):
self.deleter()
class Tensor: class Tensor:
r"""The main data container in MegEngine. r"""The main data container in MegEngine.
Use :func:`~.tensor` to create a Tensor with existed data. Use :func:`~.tensor` to create a Tensor with existed data.
...@@ -111,6 +120,7 @@ class Tensor: ...@@ -111,6 +120,7 @@ class Tensor:
self._reset(val, requires_grad=requires_grad) self._reset(val, requires_grad=requires_grad)
def _reset(self, val=None, *, requires_grad=None): def _reset(self, val=None, *, requires_grad=None):
self.__sym_override = None
if val is None: if val is None:
self.__val = None self.__val = None
self.__sym = None self.__sym = None
...@@ -154,17 +164,20 @@ class Tensor: ...@@ -154,17 +164,20 @@ class Tensor:
return self.numpy().item() return self.numpy().item()
def _attach(self, comp_graph, *, volatile=True): def _attach(self, comp_graph, *, volatile=True):
sym = self.__sym_override or self.__sym
if sym:
if sym.owner_graph != comp_graph:
raise RuntimeError("internal error")
return sym
if self.__val: if self.__val:
return self.__val.symvar(comp_graph, volatile=volatile) return self.__val.symvar(comp_graph, volatile=volatile)
if self.__sym:
if self.__sym.owner_graph != comp_graph:
raise RuntimeError("internal error")
return self.__sym
else: else:
raise ValueError("uninitialized") raise ValueError("uninitialized")
@property @property
def _symvar(self): def _symvar(self):
if self.__sym_override:
return self.__sym_override
if self.__sym: if self.__sym:
assert not self.__val assert not self.__val
return self.__sym return self.__sym
...@@ -174,10 +187,26 @@ class Tensor: ...@@ -174,10 +187,26 @@ class Tensor:
return self._attach(get_default_graph()) return self._attach(get_default_graph())
def __mgb_symvar__(self, comp_graph=None, **_): def __mgb_symvar__(self, comp_graph=None, **_):
if self.__sym_override:
return self.__sym_override
if self.__val and comp_graph: if self.__val and comp_graph:
return self._attach(comp_graph) return self._attach(comp_graph)
return self._symvar # read by mgb.opr return self._symvar # read by mgb.opr
def _override_symvar_during_trace(self, trace, symvar):
assert self.__val and not self.__sym
assert trace is type(trace)._active_instance
deleters = trace._user_cache.setdefault(Tensor, set())
self_ref = weakref.ref(self)
def restore():
self = self_ref()
if self is not None:
self.__sym_override = None
deleters.add(Guard(restore))
self.__sym_override = symvar
@property @property
def dtype(self): def dtype(self):
r"""Return the data type of the tensor. r"""Return the data type of the tensor.
......
...@@ -13,7 +13,7 @@ import megengine._internal as mgb ...@@ -13,7 +13,7 @@ import megengine._internal as mgb
from ..core.graph import get_default_graph from ..core.graph import get_default_graph
from ..core.tensor import Tensor, wrap_io_tensor from ..core.tensor import Tensor, wrap_io_tensor
from ..jit import barrier, mark_impure from ..jit import barrier, mark_impure, trace
@wrap_io_tensor @wrap_io_tensor
...@@ -112,6 +112,9 @@ def add_update( ...@@ -112,6 +112,9 @@ def add_update(
) )
mark_impure(u) mark_impure(u)
if trace._active_instance:
dest._override_symvar_during_trace(trace._active_instance, u)
return Tensor(u) return Tensor(u)
......
...@@ -367,10 +367,12 @@ class trace: ...@@ -367,10 +367,12 @@ class trace:
raise RuntimeError("nested trace is unsupported") raise RuntimeError("nested trace is unsupported")
self._status = self._STARTED self._status = self._STARTED
type(self)._active_instance = self type(self)._active_instance = self
self._user_cache = {}
try: try:
yield yield
finally: finally:
self._status = self._FINISHED self._status = self._FINISHED
self._user_cache = None
type(self)._active_instance = None type(self)._active_instance = None
def _run_wrapped(self): def _run_wrapped(self):
......
...@@ -16,6 +16,7 @@ import pytest ...@@ -16,6 +16,7 @@ import pytest
import megengine as mge import megengine as mge
import megengine._internal as mgb import megengine._internal as mgb
import megengine.module as M import megengine.module as M
from megengine import functional as F
from megengine import jit, tensor from megengine import jit, tensor
from megengine.core.tensor import Tensor from megengine.core.tensor import Tensor
from megengine.jit import SublinearMemoryConfig from megengine.jit import SublinearMemoryConfig
...@@ -57,6 +58,19 @@ def test_symbolic(): ...@@ -57,6 +58,19 @@ def test_symbolic():
f.trace(0) f.trace(0)
def test_add_update_semantic():
for symbolic in [False, True]:
x = tensor(0)
@jit.trace(symbolic=symbolic)
def f():
F.add_update(x, 1)
return x + 1
np.testing.assert_equal(f().numpy(), [2])
np.testing.assert_equal(f().numpy(), [3])
def test_dump(): def test_dump():
@jit.trace(symbolic=True) @jit.trace(symbolic=True)
def f(x, y): def f(x, y):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册