提交 b8d8886e 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

refactor(mge/tensor): combine Dict and TensorDict

GitOrigin-RevId: 6b6c03c04b7c97c29d30831e07c21e8afd3c3f40
上级 7751a067
...@@ -425,7 +425,7 @@ class Tensor: ...@@ -425,7 +425,7 @@ class Tensor:
def __getitem__(self, idx): def __getitem__(self, idx):
return wrap_io_tensor(self._symvar.__getitem__)(_wrap_idx(idx)) return wrap_io_tensor(self._symvar.__getitem__)(_wrap_idx(idx))
def set_subtensor(self, val: "Tensor"): def set_subtensor(self, val: "Tensor") -> _MGBIndexWrapper:
r""" r"""
Return a object which supports using ``__getitem__`` to set subtensor. Return a object which supports using ``__getitem__`` to set subtensor.
...@@ -433,7 +433,7 @@ class Tensor: ...@@ -433,7 +433,7 @@ class Tensor:
""" """
return _MGBIndexWrapper(self, mgb.opr.set_subtensor, val) return _MGBIndexWrapper(self, mgb.opr.set_subtensor, val)
def incr_subtensor(self, val: "Tensor"): def incr_subtensor(self, val: "Tensor") -> _MGBIndexWrapper:
r""" r"""
Return a object which supports using ``__getitem__`` to increase subtensor. Return a object which supports using ``__getitem__`` to increase subtensor.
...@@ -442,7 +442,7 @@ class Tensor: ...@@ -442,7 +442,7 @@ class Tensor:
return _MGBIndexWrapper(self, mgb.opr.incr_subtensor, val) return _MGBIndexWrapper(self, mgb.opr.incr_subtensor, val)
@property @property
def ai(self): def ai(self) -> _MGBIndexWrapper:
r""" r"""
Return a object which supports complex index method to get subtensor. Return a object which supports complex index method to get subtensor.
...@@ -465,20 +465,20 @@ class Tensor: ...@@ -465,20 +465,20 @@ class Tensor:
""" """
return _MGBIndexWrapper(self, mgb.opr.advanced_indexing) return _MGBIndexWrapper(self, mgb.opr.advanced_indexing)
def set_ai(self, val: "Tensor"): def set_ai(self, val: "Tensor") -> _MGBIndexWrapper:
r""" r"""
Equal to :meth:`~.Tensor.set_subtensor` which supports advanced indexing. Equal to :meth:`~.Tensor.set_subtensor` which supports advanced indexing.
""" """
return _MGBIndexWrapper(self, mgb.opr.set_advanced_indexing, val) return _MGBIndexWrapper(self, mgb.opr.set_advanced_indexing, val)
def incr_ai(self, val: "Tensor"): def incr_ai(self, val: "Tensor") -> _MGBIndexWrapper:
r""" r"""
Equal to :meth:`~.Tensor.incr_subtensor` which supports advanced indexing. Equal to :meth:`~.Tensor.incr_subtensor` which supports advanced indexing.
""" """
return _MGBIndexWrapper(self, mgb.opr.incr_advanced_indexing, val) return _MGBIndexWrapper(self, mgb.opr.incr_advanced_indexing, val)
@property @property
def mi(self): def mi(self) -> _MGBIndexWrapper:
r""" r"""
Return a object which supports getting subtensor by Return a object which supports getting subtensor by
the coordinates which is Cartesian product of given index. the coordinates which is Cartesian product of given index.
...@@ -502,20 +502,20 @@ class Tensor: ...@@ -502,20 +502,20 @@ class Tensor:
""" """
return _MGBIndexWrapper(self, mgb.opr.mesh_indexing) return _MGBIndexWrapper(self, mgb.opr.mesh_indexing)
def set_mi(self, val: "Tensor"): def set_mi(self, val: "Tensor") -> _MGBIndexWrapper:
r""" r"""
Equal to :meth:`~.Tensor.set_subtensor` which using mesh indexing. Equal to :meth:`~.Tensor.set_subtensor` which using mesh indexing.
""" """
return _MGBIndexWrapper(self, mgb.opr.set_mesh_indexing, val) return _MGBIndexWrapper(self, mgb.opr.set_mesh_indexing, val)
def incr_mi(self, val: "Tensor"): def incr_mi(self, val: "Tensor") -> _MGBIndexWrapper:
r""" r"""
Equal to :meth:`~.Tensor.incr_subtensor` which using mesh indexing. Equal to :meth:`~.Tensor.incr_subtensor` which using mesh indexing.
""" """
return _MGBIndexWrapper(self, mgb.opr.incr_mesh_indexing, val) return _MGBIndexWrapper(self, mgb.opr.incr_mesh_indexing, val)
@property @property
def batched_mi(self): def batched_mi(self) -> _MGBIndexWrapper:
r""" r"""
Return a object which supports getting subtensor by Return a object which supports getting subtensor by
batched mesh indexing. batched mesh indexing.
...@@ -555,13 +555,13 @@ class Tensor: ...@@ -555,13 +555,13 @@ class Tensor:
""" """
return _MGBIndexWrapper(self, mgb.opr.batched_mesh_indexing) return _MGBIndexWrapper(self, mgb.opr.batched_mesh_indexing)
def batched_set_mi(self, val: "Tensor"): def batched_set_mi(self, val: "Tensor") -> _MGBIndexWrapper:
r""" r"""
Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing. Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing.
""" """
return _MGBIndexWrapper(self, mgb.opr.batched_set_mesh_indexing, val) return _MGBIndexWrapper(self, mgb.opr.batched_set_mesh_indexing, val)
def batched_incr_mi(self, val: "Tensor"): def batched_incr_mi(self, val: "Tensor") -> _MGBIndexWrapper:
r""" r"""
Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing. Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing.
""" """
...@@ -680,18 +680,31 @@ def tensor( ...@@ -680,18 +680,31 @@ def tensor(
return Tensor(shared_nd, requires_grad=requires_grad) return Tensor(shared_nd, requires_grad=requires_grad)
class Dict(collections.MutableMapping): class TensorDict(collections.MutableMapping):
def __init__(self, *args, key=None, **kwargs): r"""
A helper class to maintain dict with Tensor key.
"""
def __init__(self, *args, **kwargs):
self.data = {} self.data = {}
if key:
self.keyfn = key
for i in args: for i in args:
self.update(i) self.update(i)
self.update(**kwargs) self.update(**kwargs)
@staticmethod class keyfn:
def keyfn(key): # pylint: disable=method-hidden def __new__(cls, x: Tensor):
return key if not isinstance(x, Tensor):
return x
return super().__new__(cls)
def __init__(self, x: Tensor):
self._data = x # do not save id directly to make pickle work
def __hash__(self):
return id(self._data)
def __eq__(self, other):
return isinstance(other, type(self)) and id(self._data) == id(other._data)
def __getitem__(self, key): def __getitem__(self, key):
_, v = self.data[self.keyfn(key)] _, v = self.data[self.keyfn(key)]
...@@ -709,24 +722,3 @@ class Dict(collections.MutableMapping): ...@@ -709,24 +722,3 @@ class Dict(collections.MutableMapping):
def __len__(self): def __len__(self):
return len(self.data) return len(self.data)
class TensorDict(Dict): # pylint: disable=too-many-ancestors
class keyfn:
def __new__(cls, x: Tensor):
if not isinstance(x, Tensor):
return x
return super().__new__(cls)
def __init__(self, x: Tensor):
self._data = x # do not save id directly to make pickle work
def __hash__(self):
return id(self._data)
def __eq__(self, other):
# pylint: disable=undefined-variable
return isinstance(other, __class__) and id(self._data) == id(other._data)
def __init__(self, *args):
super().__init__(*args)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册