提交 b3b14fdf 编写于 作者: M Megvii Engine Team

fix(mge/jit): fix add_update semantic

GitOrigin-RevId: f541ac7c6d2dcef2f31c9d623ec92ef3a567f4db
上级 2bbce2f9
......@@ -9,6 +9,7 @@
import collections
import functools
import itertools
import weakref
from typing import Union
import numpy as np
......@@ -100,6 +101,14 @@ class MGBIndexWrapper:
)(wrap_idx(idx))
class Guard:
def __init__(self, deleter):
self.deleter = deleter
def __del__(self):
self.deleter()
class Tensor:
r"""The main data container in MegEngine.
Use :func:`~.tensor` to create a Tensor with existed data.
......@@ -111,6 +120,7 @@ class Tensor:
self._reset(val, requires_grad=requires_grad)
def _reset(self, val=None, *, requires_grad=None):
self.__sym_override = None
if val is None:
self.__val = None
self.__sym = None
......@@ -154,17 +164,20 @@ class Tensor:
return self.numpy().item()
def _attach(self, comp_graph, *, volatile=True):
sym = self.__sym_override or self.__sym
if sym:
if sym.owner_graph != comp_graph:
raise RuntimeError("internal error")
return sym
if self.__val:
return self.__val.symvar(comp_graph, volatile=volatile)
if self.__sym:
if self.__sym.owner_graph != comp_graph:
raise RuntimeError("internal error")
return self.__sym
else:
raise ValueError("uninitialized")
@property
def _symvar(self):
if self.__sym_override:
return self.__sym_override
if self.__sym:
assert not self.__val
return self.__sym
......@@ -174,10 +187,26 @@ class Tensor:
return self._attach(get_default_graph())
def __mgb_symvar__(self, comp_graph=None, **_):
if self.__sym_override:
return self.__sym_override
if self.__val and comp_graph:
return self._attach(comp_graph)
return self._symvar # read by mgb.opr
def _override_symvar_during_trace(self, trace, symvar):
assert self.__val and not self.__sym
assert trace is type(trace)._active_instance
deleters = trace._user_cache.setdefault(Tensor, set())
self_ref = weakref.ref(self)
def restore():
self = self_ref()
if self is not None:
self.__sym_override = None
deleters.add(Guard(restore))
self.__sym_override = symvar
@property
def dtype(self):
r"""Return the data type of the tensor.
......
......@@ -13,7 +13,7 @@ import megengine._internal as mgb
from ..core.graph import get_default_graph
from ..core.tensor import Tensor, wrap_io_tensor
from ..jit import barrier, mark_impure
from ..jit import barrier, mark_impure, trace
@wrap_io_tensor
......@@ -112,6 +112,9 @@ def add_update(
)
mark_impure(u)
if trace._active_instance:
dest._override_symvar_during_trace(trace._active_instance, u)
return Tensor(u)
......
......@@ -367,10 +367,12 @@ class trace:
raise RuntimeError("nested trace is unsupported")
self._status = self._STARTED
type(self)._active_instance = self
self._user_cache = {}
try:
yield
finally:
self._status = self._FINISHED
self._user_cache = None
type(self)._active_instance = None
def _run_wrapped(self):
......
......@@ -16,6 +16,7 @@ import pytest
import megengine as mge
import megengine._internal as mgb
import megengine.module as M
from megengine import functional as F
from megengine import jit, tensor
from megengine.core.tensor import Tensor
from megengine.jit import SublinearMemoryConfig
......@@ -57,6 +58,19 @@ def test_symbolic():
f.trace(0)
def test_add_update_semantic():
for symbolic in [False, True]:
x = tensor(0)
@jit.trace(symbolic=symbolic)
def f():
F.add_update(x, 1)
return x + 1
np.testing.assert_equal(f().numpy(), [2])
np.testing.assert_equal(f().numpy(), [3])
def test_dump():
@jit.trace(symbolic=True)
def f(x, y):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册