From b8d8886e3518590c436d2d0a05e6506f6c001d1d Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 8 Jun 2020 17:53:53 +0800 Subject: [PATCH] refactor(mge/tensor): combine Dict and TensorDict GitOrigin-RevId: 6b6c03c04b7c97c29d30831e07c21e8afd3c3f40 --- python_module/megengine/core/tensor.py | 70 ++++++++++++-------------- 1 file changed, 31 insertions(+), 39 deletions(-) diff --git a/python_module/megengine/core/tensor.py b/python_module/megengine/core/tensor.py index 5856ff1db..8aefcb31d 100644 --- a/python_module/megengine/core/tensor.py +++ b/python_module/megengine/core/tensor.py @@ -425,7 +425,7 @@ class Tensor: def __getitem__(self, idx): return wrap_io_tensor(self._symvar.__getitem__)(_wrap_idx(idx)) - def set_subtensor(self, val: "Tensor"): + def set_subtensor(self, val: "Tensor") -> _MGBIndexWrapper: r""" Return a object which supports using ``__getitem__`` to set subtensor. @@ -433,7 +433,7 @@ class Tensor: """ return _MGBIndexWrapper(self, mgb.opr.set_subtensor, val) - def incr_subtensor(self, val: "Tensor"): + def incr_subtensor(self, val: "Tensor") -> _MGBIndexWrapper: r""" Return a object which supports using ``__getitem__`` to increase subtensor. @@ -442,7 +442,7 @@ class Tensor: return _MGBIndexWrapper(self, mgb.opr.incr_subtensor, val) @property - def ai(self): + def ai(self) -> _MGBIndexWrapper: r""" Return a object which supports complex index method to get subtensor. @@ -465,20 +465,20 @@ class Tensor: """ return _MGBIndexWrapper(self, mgb.opr.advanced_indexing) - def set_ai(self, val: "Tensor"): + def set_ai(self, val: "Tensor") -> _MGBIndexWrapper: r""" Equal to :meth:`~.Tensor.set_subtensor` which supports advanced indexing. """ return _MGBIndexWrapper(self, mgb.opr.set_advanced_indexing, val) - def incr_ai(self, val: "Tensor"): + def incr_ai(self, val: "Tensor") -> _MGBIndexWrapper: r""" Equal to :meth:`~.Tensor.incr_subtensor` which supports advanced indexing. """ return _MGBIndexWrapper(self, mgb.opr.incr_advanced_indexing, val) @property - def mi(self): + def mi(self) -> _MGBIndexWrapper: r""" Return a object which supports getting subtensor by the coordinates which is Cartesian product of given index. @@ -502,20 +502,20 @@ class Tensor: """ return _MGBIndexWrapper(self, mgb.opr.mesh_indexing) - def set_mi(self, val: "Tensor"): + def set_mi(self, val: "Tensor") -> _MGBIndexWrapper: r""" Equal to :meth:`~.Tensor.set_subtensor` which using mesh indexing. """ return _MGBIndexWrapper(self, mgb.opr.set_mesh_indexing, val) - def incr_mi(self, val: "Tensor"): + def incr_mi(self, val: "Tensor") -> _MGBIndexWrapper: r""" Equal to :meth:`~.Tensor.incr_subtensor` which using mesh indexing. """ return _MGBIndexWrapper(self, mgb.opr.incr_mesh_indexing, val) @property - def batched_mi(self): + def batched_mi(self) -> _MGBIndexWrapper: r""" Return a object which supports getting subtensor by batched mesh indexing. @@ -555,13 +555,13 @@ class Tensor: """ return _MGBIndexWrapper(self, mgb.opr.batched_mesh_indexing) - def batched_set_mi(self, val: "Tensor"): + def batched_set_mi(self, val: "Tensor") -> _MGBIndexWrapper: r""" Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing. """ return _MGBIndexWrapper(self, mgb.opr.batched_set_mesh_indexing, val) - def batched_incr_mi(self, val: "Tensor"): + def batched_incr_mi(self, val: "Tensor") -> _MGBIndexWrapper: r""" Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing. """ @@ -680,18 +680,31 @@ def tensor( return Tensor(shared_nd, requires_grad=requires_grad) -class Dict(collections.MutableMapping): - def __init__(self, *args, key=None, **kwargs): +class TensorDict(collections.MutableMapping): + r""" + A helper class to maintain dict with Tensor key. + """ + + def __init__(self, *args, **kwargs): self.data = {} - if key: - self.keyfn = key for i in args: self.update(i) self.update(**kwargs) - @staticmethod - def keyfn(key): # pylint: disable=method-hidden - return key + class keyfn: + def __new__(cls, x: Tensor): + if not isinstance(x, Tensor): + return x + return super().__new__(cls) + + def __init__(self, x: Tensor): + self._data = x # do not save id directly to make pickle work + + def __hash__(self): + return id(self._data) + + def __eq__(self, other): + return isinstance(other, type(self)) and id(self._data) == id(other._data) def __getitem__(self, key): _, v = self.data[self.keyfn(key)] @@ -709,24 +722,3 @@ class Dict(collections.MutableMapping): def __len__(self): return len(self.data) - - -class TensorDict(Dict): # pylint: disable=too-many-ancestors - class keyfn: - def __new__(cls, x: Tensor): - if not isinstance(x, Tensor): - return x - return super().__new__(cls) - - def __init__(self, x: Tensor): - self._data = x # do not save id directly to make pickle work - - def __hash__(self): - return id(self._data) - - def __eq__(self, other): - # pylint: disable=undefined-variable - return isinstance(other, __class__) and id(self._data) == id(other._data) - - def __init__(self, *args): - super().__init__(*args) -- GitLab