提交 7751a067 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

docs(mge/tensor): add advanced index related docs

GitOrigin-RevId: 31735ddac487826aa604c2a2aec2d72b572fa609
上级 7b0dbe6a
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
......@@ -10,7 +9,7 @@ import collections
import functools
import itertools
import weakref
from typing import Union
from typing import Callable, Tuple, Union
import numpy as np
......@@ -68,24 +67,38 @@ def _wrap_symbolvar_binary_op(f):
return wrapped
def wrap_slice(inp):
def _wrap_slice(inp: slice):
r"""
A wrapper to handle Tensor values in ``inp`` slice.
"""
start = inp.start._symvar if isinstance(inp.start, Tensor) else inp.start
stop = inp.stop._symvar if isinstance(inp.stop, Tensor) else inp.stop
step = inp.step._symvar if isinstance(inp.step, Tensor) else inp.step
return slice(start, stop, step)
def wrap_idx(idx):
def _wrap_idx(idx: Tuple[Union[int, "Tensor"]]):
r"""
A wrapper to handle Tensor values in ``idx``.
"""
if not isinstance(idx, tuple):
idx = (idx,)
idx = tuple(i._symvar if isinstance(i, Tensor) else i for i in idx)
idx = tuple(wrap_slice(i) if isinstance(i, slice) else i for i in idx)
idx = tuple(_wrap_slice(i) if isinstance(i, slice) else i for i in idx)
return idx
class MGBIndexWrapper:
def __init__(self, dest, mgb_index, val=None):
class _MGBIndexWrapper:
r"""
A wrapper class to handle ``__getitem__`` for index containing Tensor values.
:param dest: a destination Tensor to do indexing on.
:param mgb_index: an ``_internal`` helper function indicating how to index.
:param val: a optional Tensor parameter used for ``mgb_index``.
"""
def __init__(self, dest: "Tensor", mgb_index: Callable, val=None):
self.dest = dest
self.val = val
self.mgb_index = mgb_index
......@@ -93,16 +106,22 @@ class MGBIndexWrapper:
def __getitem__(self, idx):
if self.val is None:
return wrap_io_tensor(self.mgb_index(self.dest._symvar).__getitem__)(
wrap_idx(idx)
_wrap_idx(idx)
)
else:
return wrap_io_tensor(
self.mgb_index(self.dest._symvar, self.val._symvar).__getitem__
)(wrap_idx(idx))
)(_wrap_idx(idx))
class _Guard:
r"""
A wrapper class with custom ``__del__`` method calling ``deleter``.
:param deleter: a function to be called in ``__del__``.
"""
class Guard:
def __init__(self, deleter):
def __init__(self, deleter: Callable):
self.deleter = deleter
def __del__(self):
......@@ -161,6 +180,7 @@ class Tensor:
return self.__sym.inferred_value
def item(self):
r"""If tensor only has only one value, return it."""
return self.numpy().item()
def _attach(self, comp_graph, *, volatile=True):
......@@ -204,7 +224,7 @@ class Tensor:
if self is not None:
self.__sym_override = None
deleters.add(Guard(restore))
deleters.add(_Guard(restore))
self.__sym_override = symvar
@property
......@@ -403,43 +423,149 @@ class Tensor:
# mgb indexing family
def __getitem__(self, idx):
return wrap_io_tensor(self._symvar.__getitem__)(wrap_idx(idx))
return wrap_io_tensor(self._symvar.__getitem__)(_wrap_idx(idx))
def set_subtensor(self, val: "Tensor"):
r"""
Return a object which supports using ``__getitem__`` to set subtensor.
def set_subtensor(self, val):
return MGBIndexWrapper(self, mgb.opr.set_subtensor, val)
``c = a.set_subtensor(b)[idx]`` is equivalent to ``c = a.copy()`` and ``c[idx] = b``.
"""
return _MGBIndexWrapper(self, mgb.opr.set_subtensor, val)
def incr_subtensor(self, val):
return MGBIndexWrapper(self, mgb.opr.incr_subtensor, val)
def incr_subtensor(self, val: "Tensor"):
r"""
Return a object which supports using ``__getitem__`` to increase subtensor.
``c = a.incr_subtensor(b)[idx]`` is equivalent to ``c = a.copy()`` and ``c[idx] += b``.
"""
return _MGBIndexWrapper(self, mgb.opr.incr_subtensor, val)
@property
def ai(self):
return MGBIndexWrapper(self, mgb.opr.advanced_indexing)
r"""
Return a object which supports complex index method to get subtensor.
def set_ai(self, val):
return MGBIndexWrapper(self, mgb.opr.set_advanced_indexing, val)
Examples:
def incr_ai(self, val):
return MGBIndexWrapper(self, mgb.opr.incr_advanced_indexing, val)
.. testcode::
from megengine import tensor
a = tensor(np.arange(16, dtype=np.float32).reshape((4, 4)))
print(a.ai[:, [2, 3]])
Outputs:
.. testoutput::
Tensor([[ 2. 3.]
[ 6. 7.]
[10. 11.]
[14. 15.]])
"""
return _MGBIndexWrapper(self, mgb.opr.advanced_indexing)
def set_ai(self, val: "Tensor"):
r"""
Equal to :meth:`~.Tensor.set_subtensor` which supports advanced indexing.
"""
return _MGBIndexWrapper(self, mgb.opr.set_advanced_indexing, val)
def incr_ai(self, val: "Tensor"):
r"""
Equal to :meth:`~.Tensor.incr_subtensor` which supports advanced indexing.
"""
return _MGBIndexWrapper(self, mgb.opr.incr_advanced_indexing, val)
@property
def mi(self):
return MGBIndexWrapper(self, mgb.opr.mesh_indexing)
r"""
Return a object which supports getting subtensor by
the coordinates which is Cartesian product of given index.
Examples:
def set_mi(self, val):
return MGBIndexWrapper(self, mgb.opr.set_mesh_indexing, val)
.. testcode::
from megengine import tensor
a = tensor(np.arange(16, dtype=np.float32).reshape((4, 4)))
print(a.mi[[1, 2], [2, 3]])
# is equal to elements on [1, 2] * [2, 3] = [[(1,2), (1, 3)], [(2, 2), (2, 3)]]
# a[1,2] = 6, a[1,3] = 7, a[2,2] = 10, a[2,3] = 11
Outputs:
.. testoutput::
Tensor([[ 6. 7.]
[10. 11.]])
"""
return _MGBIndexWrapper(self, mgb.opr.mesh_indexing)
def set_mi(self, val: "Tensor"):
r"""
Equal to :meth:`~.Tensor.set_subtensor` which using mesh indexing.
"""
return _MGBIndexWrapper(self, mgb.opr.set_mesh_indexing, val)
def incr_mi(self, val):
return MGBIndexWrapper(self, mgb.opr.incr_mesh_indexing, val)
def incr_mi(self, val: "Tensor"):
r"""
Equal to :meth:`~.Tensor.incr_subtensor` which using mesh indexing.
"""
return _MGBIndexWrapper(self, mgb.opr.incr_mesh_indexing, val)
@property
def batched_mi(self):
return MGBIndexWrapper(self, mgb.opr.batched_mesh_indexing)
r"""
Return a object which supports getting subtensor by
batched mesh indexing.
def batched_set_mi(self, val):
return MGBIndexWrapper(self, mgb.opr.batched_set_mesh_indexing, val)
For Tensor ``a`` and index ``idx``, each value of the ``idx`` need to be a 2-dim matrix or slice.
Cartesian product ``... * idx[k-1][i] * idx[k][i] * idx[k+1][i] * ...`` will be a subtensor from ``a[i]``.
Each matrix ``idx[k]`` should have the size of ``batched_dim`` rows as ``idx[0]`` indicated.
And for slice value, it will apply same slice for each ``batched_dim``. For more details see the example below.
def batched_incr_mi(self, val):
return MGBIndexWrapper(self, mgb.opr.batched_incr_mesh_indexing, val)
Examples:
.. testcode::
from megengine import tensor
a = tensor(np.arange(144, dtype=np.float32).reshape((3, 3, 4, 4)))
print(a.batched_mi[:2, [[0],[1]],[[0,1],[2,3]],[[0],[1]]])
# is equal to elements from a[0] with ``[0] * [0,1] * [0] = [[[(0,0,0)], [(0,1,0)]]]``(shape is [1,2,1])
# and from a[1] with ``[1] * [2,3] * [1] = [[[(1,2,1)], [(1,3,1)]]]``(shape is also [1,2,1])
# a[0,0,0,0] = 0, a[0,0,1,0] = 4, a[1,1,2,1] = 73, a[1,1,3,1] = 77
print(a.batched_mi[:2, [[0],[1]], :2, :1])
# is equal to ``a.batched_mi[:2, [[0],[1]], [[0,1],[0,1]],[[0],[0]]]``
Outputs:
.. testoutput::
Tensor([[[[ 0.]
[ 4.]]]
[[[73.]
[77.]]]])
Tensor([[[[ 0.]
[ 4.]]]
[[[64.]
[68.]]]])
"""
return _MGBIndexWrapper(self, mgb.opr.batched_mesh_indexing)
def batched_set_mi(self, val: "Tensor"):
r"""
Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing.
"""
return _MGBIndexWrapper(self, mgb.opr.batched_set_mesh_indexing, val)
def batched_incr_mi(self, val: "Tensor"):
r"""
Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing.
"""
return _MGBIndexWrapper(self, mgb.opr.batched_incr_mesh_indexing, val)
def __array__(self, dtype=None):
if dtype is None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册