提交 b8d8886e 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

refactor(mge/tensor): combine Dict and TensorDict

GitOrigin-RevId: 6b6c03c04b7c97c29d30831e07c21e8afd3c3f40
上级 7751a067
......@@ -425,7 +425,7 @@ class Tensor:
def __getitem__(self, idx):
return wrap_io_tensor(self._symvar.__getitem__)(_wrap_idx(idx))
def set_subtensor(self, val: "Tensor"):
def set_subtensor(self, val: "Tensor") -> _MGBIndexWrapper:
r"""
Return a object which supports using ``__getitem__`` to set subtensor.
......@@ -433,7 +433,7 @@ class Tensor:
"""
return _MGBIndexWrapper(self, mgb.opr.set_subtensor, val)
def incr_subtensor(self, val: "Tensor"):
def incr_subtensor(self, val: "Tensor") -> _MGBIndexWrapper:
r"""
Return a object which supports using ``__getitem__`` to increase subtensor.
......@@ -442,7 +442,7 @@ class Tensor:
return _MGBIndexWrapper(self, mgb.opr.incr_subtensor, val)
@property
def ai(self):
def ai(self) -> _MGBIndexWrapper:
r"""
Return a object which supports complex index method to get subtensor.
......@@ -465,20 +465,20 @@ class Tensor:
"""
return _MGBIndexWrapper(self, mgb.opr.advanced_indexing)
def set_ai(self, val: "Tensor"):
def set_ai(self, val: "Tensor") -> _MGBIndexWrapper:
r"""
Equal to :meth:`~.Tensor.set_subtensor` which supports advanced indexing.
"""
return _MGBIndexWrapper(self, mgb.opr.set_advanced_indexing, val)
def incr_ai(self, val: "Tensor"):
def incr_ai(self, val: "Tensor") -> _MGBIndexWrapper:
r"""
Equal to :meth:`~.Tensor.incr_subtensor` which supports advanced indexing.
"""
return _MGBIndexWrapper(self, mgb.opr.incr_advanced_indexing, val)
@property
def mi(self):
def mi(self) -> _MGBIndexWrapper:
r"""
Return a object which supports getting subtensor by
the coordinates which is Cartesian product of given index.
......@@ -502,20 +502,20 @@ class Tensor:
"""
return _MGBIndexWrapper(self, mgb.opr.mesh_indexing)
def set_mi(self, val: "Tensor"):
def set_mi(self, val: "Tensor") -> _MGBIndexWrapper:
r"""
Equal to :meth:`~.Tensor.set_subtensor` which using mesh indexing.
"""
return _MGBIndexWrapper(self, mgb.opr.set_mesh_indexing, val)
def incr_mi(self, val: "Tensor"):
def incr_mi(self, val: "Tensor") -> _MGBIndexWrapper:
r"""
Equal to :meth:`~.Tensor.incr_subtensor` which using mesh indexing.
"""
return _MGBIndexWrapper(self, mgb.opr.incr_mesh_indexing, val)
@property
def batched_mi(self):
def batched_mi(self) -> _MGBIndexWrapper:
r"""
Return a object which supports getting subtensor by
batched mesh indexing.
......@@ -555,13 +555,13 @@ class Tensor:
"""
return _MGBIndexWrapper(self, mgb.opr.batched_mesh_indexing)
def batched_set_mi(self, val: "Tensor"):
def batched_set_mi(self, val: "Tensor") -> _MGBIndexWrapper:
r"""
Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing.
"""
return _MGBIndexWrapper(self, mgb.opr.batched_set_mesh_indexing, val)
def batched_incr_mi(self, val: "Tensor"):
def batched_incr_mi(self, val: "Tensor") -> _MGBIndexWrapper:
r"""
Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing.
"""
......@@ -680,18 +680,31 @@ def tensor(
return Tensor(shared_nd, requires_grad=requires_grad)
class Dict(collections.MutableMapping):
def __init__(self, *args, key=None, **kwargs):
class TensorDict(collections.MutableMapping):
r"""
A helper class to maintain dict with Tensor key.
"""
def __init__(self, *args, **kwargs):
self.data = {}
if key:
self.keyfn = key
for i in args:
self.update(i)
self.update(**kwargs)
@staticmethod
def keyfn(key): # pylint: disable=method-hidden
return key
class keyfn:
def __new__(cls, x: Tensor):
if not isinstance(x, Tensor):
return x
return super().__new__(cls)
def __init__(self, x: Tensor):
self._data = x # do not save id directly to make pickle work
def __hash__(self):
return id(self._data)
def __eq__(self, other):
return isinstance(other, type(self)) and id(self._data) == id(other._data)
def __getitem__(self, key):
_, v = self.data[self.keyfn(key)]
......@@ -709,24 +722,3 @@ class Dict(collections.MutableMapping):
def __len__(self):
return len(self.data)
class TensorDict(Dict): # pylint: disable=too-many-ancestors
class keyfn:
def __new__(cls, x: Tensor):
if not isinstance(x, Tensor):
return x
return super().__new__(cls)
def __init__(self, x: Tensor):
self._data = x # do not save id directly to make pickle work
def __hash__(self):
return id(self._data)
def __eq__(self, other):
# pylint: disable=undefined-variable
return isinstance(other, __class__) and id(self._data) == id(other._data)
def __init__(self, *args):
super().__init__(*args)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册