From 7751a0676ecc85f33e2f1b831b18872c36268813 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 8 Jun 2020 17:51:55 +0800 Subject: [PATCH] docs(mge/tensor): add advanced index related docs GitOrigin-RevId: 31735ddac487826aa604c2a2aec2d72b572fa609 --- python_module/megengine/core/tensor.py | 190 ++++++++++++++++++++----- 1 file changed, 158 insertions(+), 32 deletions(-) diff --git a/python_module/megengine/core/tensor.py b/python_module/megengine/core/tensor.py index 575ed7d53..5856ff1db 100644 --- a/python_module/megengine/core/tensor.py +++ b/python_module/megengine/core/tensor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") # # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. @@ -10,7 +9,7 @@ import collections import functools import itertools import weakref -from typing import Union +from typing import Callable, Tuple, Union import numpy as np @@ -68,24 +67,38 @@ def _wrap_symbolvar_binary_op(f): return wrapped -def wrap_slice(inp): +def _wrap_slice(inp: slice): + r""" + A wrapper to handle Tensor values in ``inp`` slice. + """ start = inp.start._symvar if isinstance(inp.start, Tensor) else inp.start stop = inp.stop._symvar if isinstance(inp.stop, Tensor) else inp.stop step = inp.step._symvar if isinstance(inp.step, Tensor) else inp.step return slice(start, stop, step) -def wrap_idx(idx): +def _wrap_idx(idx: Tuple[Union[int, "Tensor"]]): + r""" + A wrapper to handle Tensor values in ``idx``. + """ if not isinstance(idx, tuple): idx = (idx,) idx = tuple(i._symvar if isinstance(i, Tensor) else i for i in idx) - idx = tuple(wrap_slice(i) if isinstance(i, slice) else i for i in idx) + idx = tuple(_wrap_slice(i) if isinstance(i, slice) else i for i in idx) return idx -class MGBIndexWrapper: - def __init__(self, dest, mgb_index, val=None): +class _MGBIndexWrapper: + r""" + A wrapper class to handle ``__getitem__`` for index containing Tensor values. + + :param dest: a destination Tensor to do indexing on. + :param mgb_index: an ``_internal`` helper function indicating how to index. + :param val: a optional Tensor parameter used for ``mgb_index``. + """ + + def __init__(self, dest: "Tensor", mgb_index: Callable, val=None): self.dest = dest self.val = val self.mgb_index = mgb_index @@ -93,16 +106,22 @@ class MGBIndexWrapper: def __getitem__(self, idx): if self.val is None: return wrap_io_tensor(self.mgb_index(self.dest._symvar).__getitem__)( - wrap_idx(idx) + _wrap_idx(idx) ) else: return wrap_io_tensor( self.mgb_index(self.dest._symvar, self.val._symvar).__getitem__ - )(wrap_idx(idx)) + )(_wrap_idx(idx)) + + +class _Guard: + r""" + A wrapper class with custom ``__del__`` method calling ``deleter``. + :param deleter: a function to be called in ``__del__``. + """ -class Guard: - def __init__(self, deleter): + def __init__(self, deleter: Callable): self.deleter = deleter def __del__(self): @@ -161,6 +180,7 @@ class Tensor: return self.__sym.inferred_value def item(self): + r"""If tensor only has only one value, return it.""" return self.numpy().item() def _attach(self, comp_graph, *, volatile=True): @@ -204,7 +224,7 @@ class Tensor: if self is not None: self.__sym_override = None - deleters.add(Guard(restore)) + deleters.add(_Guard(restore)) self.__sym_override = symvar @property @@ -403,43 +423,149 @@ class Tensor: # mgb indexing family def __getitem__(self, idx): - return wrap_io_tensor(self._symvar.__getitem__)(wrap_idx(idx)) + return wrap_io_tensor(self._symvar.__getitem__)(_wrap_idx(idx)) + + def set_subtensor(self, val: "Tensor"): + r""" + Return a object which supports using ``__getitem__`` to set subtensor. - def set_subtensor(self, val): - return MGBIndexWrapper(self, mgb.opr.set_subtensor, val) + ``c = a.set_subtensor(b)[idx]`` is equivalent to ``c = a.copy()`` and ``c[idx] = b``. + """ + return _MGBIndexWrapper(self, mgb.opr.set_subtensor, val) - def incr_subtensor(self, val): - return MGBIndexWrapper(self, mgb.opr.incr_subtensor, val) + def incr_subtensor(self, val: "Tensor"): + r""" + Return a object which supports using ``__getitem__`` to increase subtensor. + + ``c = a.incr_subtensor(b)[idx]`` is equivalent to ``c = a.copy()`` and ``c[idx] += b``. + """ + return _MGBIndexWrapper(self, mgb.opr.incr_subtensor, val) @property def ai(self): - return MGBIndexWrapper(self, mgb.opr.advanced_indexing) + r""" + Return a object which supports complex index method to get subtensor. - def set_ai(self, val): - return MGBIndexWrapper(self, mgb.opr.set_advanced_indexing, val) + Examples: - def incr_ai(self, val): - return MGBIndexWrapper(self, mgb.opr.incr_advanced_indexing, val) + .. testcode:: + + from megengine import tensor + a = tensor(np.arange(16, dtype=np.float32).reshape((4, 4))) + print(a.ai[:, [2, 3]]) + + Outputs: + + .. testoutput:: + + Tensor([[ 2. 3.] + [ 6. 7.] + [10. 11.] + [14. 15.]]) + """ + return _MGBIndexWrapper(self, mgb.opr.advanced_indexing) + + def set_ai(self, val: "Tensor"): + r""" + Equal to :meth:`~.Tensor.set_subtensor` which supports advanced indexing. + """ + return _MGBIndexWrapper(self, mgb.opr.set_advanced_indexing, val) + + def incr_ai(self, val: "Tensor"): + r""" + Equal to :meth:`~.Tensor.incr_subtensor` which supports advanced indexing. + """ + return _MGBIndexWrapper(self, mgb.opr.incr_advanced_indexing, val) @property def mi(self): - return MGBIndexWrapper(self, mgb.opr.mesh_indexing) + r""" + Return a object which supports getting subtensor by + the coordinates which is Cartesian product of given index. + + Examples: - def set_mi(self, val): - return MGBIndexWrapper(self, mgb.opr.set_mesh_indexing, val) + .. testcode:: + + from megengine import tensor + a = tensor(np.arange(16, dtype=np.float32).reshape((4, 4))) + print(a.mi[[1, 2], [2, 3]]) + # is equal to elements on [1, 2] * [2, 3] = [[(1,2), (1, 3)], [(2, 2), (2, 3)]] + # a[1,2] = 6, a[1,3] = 7, a[2,2] = 10, a[2,3] = 11 + + Outputs: + + .. testoutput:: + + Tensor([[ 6. 7.] + [10. 11.]]) + """ + return _MGBIndexWrapper(self, mgb.opr.mesh_indexing) + + def set_mi(self, val: "Tensor"): + r""" + Equal to :meth:`~.Tensor.set_subtensor` which using mesh indexing. + """ + return _MGBIndexWrapper(self, mgb.opr.set_mesh_indexing, val) - def incr_mi(self, val): - return MGBIndexWrapper(self, mgb.opr.incr_mesh_indexing, val) + def incr_mi(self, val: "Tensor"): + r""" + Equal to :meth:`~.Tensor.incr_subtensor` which using mesh indexing. + """ + return _MGBIndexWrapper(self, mgb.opr.incr_mesh_indexing, val) @property def batched_mi(self): - return MGBIndexWrapper(self, mgb.opr.batched_mesh_indexing) + r""" + Return a object which supports getting subtensor by + batched mesh indexing. - def batched_set_mi(self, val): - return MGBIndexWrapper(self, mgb.opr.batched_set_mesh_indexing, val) + For Tensor ``a`` and index ``idx``, each value of the ``idx`` need to be a 2-dim matrix or slice. + Cartesian product ``... * idx[k-1][i] * idx[k][i] * idx[k+1][i] * ...`` will be a subtensor from ``a[i]``. + Each matrix ``idx[k]`` should have the size of ``batched_dim`` rows as ``idx[0]`` indicated. + And for slice value, it will apply same slice for each ``batched_dim``. For more details see the example below. - def batched_incr_mi(self, val): - return MGBIndexWrapper(self, mgb.opr.batched_incr_mesh_indexing, val) + Examples: + + .. testcode:: + + from megengine import tensor + a = tensor(np.arange(144, dtype=np.float32).reshape((3, 3, 4, 4))) + + print(a.batched_mi[:2, [[0],[1]],[[0,1],[2,3]],[[0],[1]]]) + # is equal to elements from a[0] with ``[0] * [0,1] * [0] = [[[(0,0,0)], [(0,1,0)]]]``(shape is [1,2,1]) + # and from a[1] with ``[1] * [2,3] * [1] = [[[(1,2,1)], [(1,3,1)]]]``(shape is also [1,2,1]) + # a[0,0,0,0] = 0, a[0,0,1,0] = 4, a[1,1,2,1] = 73, a[1,1,3,1] = 77 + + print(a.batched_mi[:2, [[0],[1]], :2, :1]) + # is equal to ``a.batched_mi[:2, [[0],[1]], [[0,1],[0,1]],[[0],[0]]]`` + + Outputs: + + .. testoutput:: + + Tensor([[[[ 0.] + [ 4.]]] + [[[73.] + [77.]]]]) + Tensor([[[[ 0.] + [ 4.]]] + [[[64.] + [68.]]]]) + """ + return _MGBIndexWrapper(self, mgb.opr.batched_mesh_indexing) + + def batched_set_mi(self, val: "Tensor"): + r""" + Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing. + """ + return _MGBIndexWrapper(self, mgb.opr.batched_set_mesh_indexing, val) + + def batched_incr_mi(self, val: "Tensor"): + r""" + Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing. + """ + return _MGBIndexWrapper(self, mgb.opr.batched_incr_mesh_indexing, val) def __array__(self, dtype=None): if dtype is None: -- GitLab