提交 7751a067 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

docs(mge/tensor): add advanced index related docs

GitOrigin-RevId: 31735ddac487826aa604c2a2aec2d72b572fa609
上级 7b0dbe6a
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
# #
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
...@@ -10,7 +9,7 @@ import collections ...@@ -10,7 +9,7 @@ import collections
import functools import functools
import itertools import itertools
import weakref import weakref
from typing import Union from typing import Callable, Tuple, Union
import numpy as np import numpy as np
...@@ -68,24 +67,38 @@ def _wrap_symbolvar_binary_op(f): ...@@ -68,24 +67,38 @@ def _wrap_symbolvar_binary_op(f):
return wrapped return wrapped
def wrap_slice(inp): def _wrap_slice(inp: slice):
r"""
A wrapper to handle Tensor values in ``inp`` slice.
"""
start = inp.start._symvar if isinstance(inp.start, Tensor) else inp.start start = inp.start._symvar if isinstance(inp.start, Tensor) else inp.start
stop = inp.stop._symvar if isinstance(inp.stop, Tensor) else inp.stop stop = inp.stop._symvar if isinstance(inp.stop, Tensor) else inp.stop
step = inp.step._symvar if isinstance(inp.step, Tensor) else inp.step step = inp.step._symvar if isinstance(inp.step, Tensor) else inp.step
return slice(start, stop, step) return slice(start, stop, step)
def wrap_idx(idx): def _wrap_idx(idx: Tuple[Union[int, "Tensor"]]):
r"""
A wrapper to handle Tensor values in ``idx``.
"""
if not isinstance(idx, tuple): if not isinstance(idx, tuple):
idx = (idx,) idx = (idx,)
idx = tuple(i._symvar if isinstance(i, Tensor) else i for i in idx) idx = tuple(i._symvar if isinstance(i, Tensor) else i for i in idx)
idx = tuple(wrap_slice(i) if isinstance(i, slice) else i for i in idx) idx = tuple(_wrap_slice(i) if isinstance(i, slice) else i for i in idx)
return idx return idx
class MGBIndexWrapper: class _MGBIndexWrapper:
def __init__(self, dest, mgb_index, val=None): r"""
A wrapper class to handle ``__getitem__`` for index containing Tensor values.
:param dest: a destination Tensor to do indexing on.
:param mgb_index: an ``_internal`` helper function indicating how to index.
:param val: a optional Tensor parameter used for ``mgb_index``.
"""
def __init__(self, dest: "Tensor", mgb_index: Callable, val=None):
self.dest = dest self.dest = dest
self.val = val self.val = val
self.mgb_index = mgb_index self.mgb_index = mgb_index
...@@ -93,16 +106,22 @@ class MGBIndexWrapper: ...@@ -93,16 +106,22 @@ class MGBIndexWrapper:
def __getitem__(self, idx): def __getitem__(self, idx):
if self.val is None: if self.val is None:
return wrap_io_tensor(self.mgb_index(self.dest._symvar).__getitem__)( return wrap_io_tensor(self.mgb_index(self.dest._symvar).__getitem__)(
wrap_idx(idx) _wrap_idx(idx)
) )
else: else:
return wrap_io_tensor( return wrap_io_tensor(
self.mgb_index(self.dest._symvar, self.val._symvar).__getitem__ self.mgb_index(self.dest._symvar, self.val._symvar).__getitem__
)(wrap_idx(idx)) )(_wrap_idx(idx))
class _Guard:
r"""
A wrapper class with custom ``__del__`` method calling ``deleter``.
:param deleter: a function to be called in ``__del__``.
"""
class Guard: def __init__(self, deleter: Callable):
def __init__(self, deleter):
self.deleter = deleter self.deleter = deleter
def __del__(self): def __del__(self):
...@@ -161,6 +180,7 @@ class Tensor: ...@@ -161,6 +180,7 @@ class Tensor:
return self.__sym.inferred_value return self.__sym.inferred_value
def item(self): def item(self):
r"""If tensor only has only one value, return it."""
return self.numpy().item() return self.numpy().item()
def _attach(self, comp_graph, *, volatile=True): def _attach(self, comp_graph, *, volatile=True):
...@@ -204,7 +224,7 @@ class Tensor: ...@@ -204,7 +224,7 @@ class Tensor:
if self is not None: if self is not None:
self.__sym_override = None self.__sym_override = None
deleters.add(Guard(restore)) deleters.add(_Guard(restore))
self.__sym_override = symvar self.__sym_override = symvar
@property @property
...@@ -403,43 +423,149 @@ class Tensor: ...@@ -403,43 +423,149 @@ class Tensor:
# mgb indexing family # mgb indexing family
def __getitem__(self, idx): def __getitem__(self, idx):
return wrap_io_tensor(self._symvar.__getitem__)(wrap_idx(idx)) return wrap_io_tensor(self._symvar.__getitem__)(_wrap_idx(idx))
def set_subtensor(self, val: "Tensor"):
r"""
Return a object which supports using ``__getitem__`` to set subtensor.
def set_subtensor(self, val): ``c = a.set_subtensor(b)[idx]`` is equivalent to ``c = a.copy()`` and ``c[idx] = b``.
return MGBIndexWrapper(self, mgb.opr.set_subtensor, val) """
return _MGBIndexWrapper(self, mgb.opr.set_subtensor, val)
def incr_subtensor(self, val): def incr_subtensor(self, val: "Tensor"):
return MGBIndexWrapper(self, mgb.opr.incr_subtensor, val) r"""
Return a object which supports using ``__getitem__`` to increase subtensor.
``c = a.incr_subtensor(b)[idx]`` is equivalent to ``c = a.copy()`` and ``c[idx] += b``.
"""
return _MGBIndexWrapper(self, mgb.opr.incr_subtensor, val)
@property @property
def ai(self): def ai(self):
return MGBIndexWrapper(self, mgb.opr.advanced_indexing) r"""
Return a object which supports complex index method to get subtensor.
def set_ai(self, val): Examples:
return MGBIndexWrapper(self, mgb.opr.set_advanced_indexing, val)
def incr_ai(self, val): .. testcode::
return MGBIndexWrapper(self, mgb.opr.incr_advanced_indexing, val)
from megengine import tensor
a = tensor(np.arange(16, dtype=np.float32).reshape((4, 4)))
print(a.ai[:, [2, 3]])
Outputs:
.. testoutput::
Tensor([[ 2. 3.]
[ 6. 7.]
[10. 11.]
[14. 15.]])
"""
return _MGBIndexWrapper(self, mgb.opr.advanced_indexing)
def set_ai(self, val: "Tensor"):
r"""
Equal to :meth:`~.Tensor.set_subtensor` which supports advanced indexing.
"""
return _MGBIndexWrapper(self, mgb.opr.set_advanced_indexing, val)
def incr_ai(self, val: "Tensor"):
r"""
Equal to :meth:`~.Tensor.incr_subtensor` which supports advanced indexing.
"""
return _MGBIndexWrapper(self, mgb.opr.incr_advanced_indexing, val)
@property @property
def mi(self): def mi(self):
return MGBIndexWrapper(self, mgb.opr.mesh_indexing) r"""
Return a object which supports getting subtensor by
the coordinates which is Cartesian product of given index.
Examples:
def set_mi(self, val): .. testcode::
return MGBIndexWrapper(self, mgb.opr.set_mesh_indexing, val)
from megengine import tensor
a = tensor(np.arange(16, dtype=np.float32).reshape((4, 4)))
print(a.mi[[1, 2], [2, 3]])
# is equal to elements on [1, 2] * [2, 3] = [[(1,2), (1, 3)], [(2, 2), (2, 3)]]
# a[1,2] = 6, a[1,3] = 7, a[2,2] = 10, a[2,3] = 11
Outputs:
.. testoutput::
Tensor([[ 6. 7.]
[10. 11.]])
"""
return _MGBIndexWrapper(self, mgb.opr.mesh_indexing)
def set_mi(self, val: "Tensor"):
r"""
Equal to :meth:`~.Tensor.set_subtensor` which using mesh indexing.
"""
return _MGBIndexWrapper(self, mgb.opr.set_mesh_indexing, val)
def incr_mi(self, val): def incr_mi(self, val: "Tensor"):
return MGBIndexWrapper(self, mgb.opr.incr_mesh_indexing, val) r"""
Equal to :meth:`~.Tensor.incr_subtensor` which using mesh indexing.
"""
return _MGBIndexWrapper(self, mgb.opr.incr_mesh_indexing, val)
@property @property
def batched_mi(self): def batched_mi(self):
return MGBIndexWrapper(self, mgb.opr.batched_mesh_indexing) r"""
Return a object which supports getting subtensor by
batched mesh indexing.
def batched_set_mi(self, val): For Tensor ``a`` and index ``idx``, each value of the ``idx`` need to be a 2-dim matrix or slice.
return MGBIndexWrapper(self, mgb.opr.batched_set_mesh_indexing, val) Cartesian product ``... * idx[k-1][i] * idx[k][i] * idx[k+1][i] * ...`` will be a subtensor from ``a[i]``.
Each matrix ``idx[k]`` should have the size of ``batched_dim`` rows as ``idx[0]`` indicated.
And for slice value, it will apply same slice for each ``batched_dim``. For more details see the example below.
def batched_incr_mi(self, val): Examples:
return MGBIndexWrapper(self, mgb.opr.batched_incr_mesh_indexing, val)
.. testcode::
from megengine import tensor
a = tensor(np.arange(144, dtype=np.float32).reshape((3, 3, 4, 4)))
print(a.batched_mi[:2, [[0],[1]],[[0,1],[2,3]],[[0],[1]]])
# is equal to elements from a[0] with ``[0] * [0,1] * [0] = [[[(0,0,0)], [(0,1,0)]]]``(shape is [1,2,1])
# and from a[1] with ``[1] * [2,3] * [1] = [[[(1,2,1)], [(1,3,1)]]]``(shape is also [1,2,1])
# a[0,0,0,0] = 0, a[0,0,1,0] = 4, a[1,1,2,1] = 73, a[1,1,3,1] = 77
print(a.batched_mi[:2, [[0],[1]], :2, :1])
# is equal to ``a.batched_mi[:2, [[0],[1]], [[0,1],[0,1]],[[0],[0]]]``
Outputs:
.. testoutput::
Tensor([[[[ 0.]
[ 4.]]]
[[[73.]
[77.]]]])
Tensor([[[[ 0.]
[ 4.]]]
[[[64.]
[68.]]]])
"""
return _MGBIndexWrapper(self, mgb.opr.batched_mesh_indexing)
def batched_set_mi(self, val: "Tensor"):
r"""
Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing.
"""
return _MGBIndexWrapper(self, mgb.opr.batched_set_mesh_indexing, val)
def batched_incr_mi(self, val: "Tensor"):
r"""
Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing.
"""
return _MGBIndexWrapper(self, mgb.opr.batched_incr_mesh_indexing, val)
def __array__(self, dtype=None): def __array__(self, dtype=None):
if dtype is None: if dtype is None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册