From 1e2117f66d700f9c58c260b048558c67b0bd440e Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 9 Oct 2020 22:51:57 +0800 Subject: [PATCH] feat(mge): remove add_axis and remove_axis GitOrigin-RevId: 59611d43f979ce2de3b8fdb078c3e8fca6e9bf49 --- .../megengine/core/tensor/tensor_wrapper.py | 2 +- .../python/megengine/functional/math.py | 10 ++--- imperative/python/megengine/functional/nn.py | 14 +++---- .../python/megengine/functional/tensor.py | 39 +++++++------------ .../python/test/unit/core/test_autodiff.py | 2 +- .../test/unit/functional/test_tensor.py | 4 +- 6 files changed, 31 insertions(+), 40 deletions(-) diff --git a/imperative/python/megengine/core/tensor/tensor_wrapper.py b/imperative/python/megengine/core/tensor/tensor_wrapper.py index 9bbc83866..800146e0e 100644 --- a/imperative/python/megengine/core/tensor/tensor_wrapper.py +++ b/imperative/python/megengine/core/tensor/tensor_wrapper.py @@ -411,7 +411,7 @@ class ArrayMethodMixin(abc.ABC): r"""Returns the sum of each row of the input tensor in the given dimension ``axis``. If ``axis`` is a list of axises, reduce over all of them. - If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor, except in the dimension(s) ``axis`` where it is of size 1. Otherwise, ``axis`` is squeezed(see :meth:`~.functional.tensor.remove_axis`). + If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor, except in the dimension(s) ``axis`` where it is of size 1. Otherwise, ``axis`` is squeezed(see :meth:`~.functional.tensor.squeeze`). Same for prod/mean/max/min. diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py index a61ea8dd7..2c0ab6094 100644 --- a/imperative/python/megengine/functional/math.py +++ b/imperative/python/megengine/functional/math.py @@ -19,7 +19,7 @@ from ..core.tensor import utils from ..core.tensor.core import TensorBase, TensorWrapperBase, apply from ..tensor import Tensor from .elemwise import clip, exp, log, log1p -from .tensor import add_axis, remove_axis, reshape +from .tensor import reshape, squeeze __all__ = [ "argmax", @@ -459,7 +459,7 @@ def argmin( (inp,) = apply(op, inp) if not keepdims: - inp = remove_axis(inp, ai) + inp = squeeze(inp, ai) return inp @@ -471,7 +471,7 @@ def argmin( op = builtin.Argmin(axis=axis) (result,) = apply(op, inp) if not keepdims: - result = remove_axis(result, axis) + result = squeeze(result, axis) return result @@ -517,7 +517,7 @@ def argmax( (inp,) = apply(op, inp) if not keepdims: - inp = remove_axis(inp, ai) + inp = squeeze(inp, ai) return inp @@ -529,7 +529,7 @@ def argmax( op = builtin.Argmax(axis=axis) (result,) = apply(op, inp) if not keepdims: - result = remove_axis(result, axis) + result = squeeze(result, axis) return result diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 0d811eb51..6d15ff3f2 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -23,7 +23,7 @@ from .debug_param import get_conv_execution_strategy from .distributed import all_reduce_sum from .elemwise import exp, floor, log, log1p, maximum, minimum, relu from .math import argsort, max, sum -from .tensor import add_axis, broadcast, concat, full, ones, remove_axis, reshape, zeros +from .tensor import broadcast, concat, expand_dims, full, ones, reshape, squeeze, zeros from .types import _pair, _pair_nonzero __all__ = [ @@ -542,7 +542,7 @@ def logsumexp( if keepdims: return max_value + log(sum(exp(inp - max_value), axis, keepdims)) else: - return remove_axis(max_value, axis=None) + log( + return squeeze(max_value, axis=None) + log( sum(exp(inp - max_value), axis, keepdims) ) @@ -640,7 +640,7 @@ def batch_norm2d( def expand_or_full(x, value): if x is None: return full_value(value) - return add_axis(x, [0, 2, 3]) + return expand_dims(x, [0, 2, 3]) def make_full_if_none(x, value): if x is None: @@ -998,10 +998,10 @@ def matmul( else: if dim1 == 1: shp = (inp2.shape[1],) - inp1 = add_axis(inp1, 0) + inp1 = expand_dims(inp1, 0) if dim2 == 1: shp = (inp1.shape[0],) - inp2 = add_axis(inp2, 1) + inp2 = expand_dims(inp2, 1) op = builtin.MatrixMul( transposeA=transpose_a, transposeB=transpose_b, @@ -1135,7 +1135,7 @@ def interpolate( align_corners = False if mode == "LINEAR": - inp = add_axis(inp, 3) + inp = expand_dims(inp, 3) if inp.ndim != 4: raise ValueError("shape of input tensor must correspond to the operartion mode") @@ -1452,7 +1452,7 @@ def indexing_one_hot( index = utils.convert_single_value(index, (src,), dtype="int32", device=src.device) (result,) = apply(op, src, index) if not keepdims: - result = remove_axis(result, axis) + result = squeeze(result, axis) return result diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 84049c48c..76bd10552 100644 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -32,12 +32,12 @@ from ..tensor import Tensor from .elemwise import ceil __all__ = [ - "add_axis", "arange", "broadcast", "concat", "cond_take", "dimshuffle", + "expand_dims", "eye", "flatten", "full", @@ -50,7 +50,6 @@ __all__ = [ "param_pack_concat", "param_pack_split", "reshape", - "remove_axis", "split", "squeeze", "stack", @@ -160,7 +159,7 @@ def zeros_like(inp: Tensor) -> Tensor: print(out.numpy()) Outputs: - + .. testoutput:: [[0 0 0] @@ -320,7 +319,7 @@ def stack(inps, axis=0, device=None): if len(shapes) != 1: raise ValueError("All input tensors must have the same shape") - inps = [add_axis(inp, axis=axis) for inp in inps] + inps = [expand_dims(inp, axis=axis) for inp in inps] return concat(inps, axis=axis, device=device) @@ -480,7 +479,7 @@ def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor: def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor: - r"""Writes all values from the tensor source into input tensor + r"""Writes all values from the tensor source into input tensor at the indices specified in the index tensor. For each value in source, its output index is specified by its index @@ -694,7 +693,7 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor: Swaps shapes and strides according to given pattern. :param inp: input tensor. - :param pattern: a list of integers including 0, 1, ... , ``ndim``-1, + :param pattern: a list of integers including 0, 1, ... , ``ndim``-1, and any number of ``'x'`` char in dimensions where this tensor should be broadcasted. For examples: * (``'x'``) -> make a 0d (scalar) into a 1d vector @@ -815,7 +814,7 @@ def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor: return inp.reshape(*target_shape) -def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: +def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: r""" Adds dimension before given axis. @@ -832,7 +831,7 @@ def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: import megengine.functional as F x = tensor([1, 2]) - out = F.add_axis(x, 0) + out = F.expand_dims(x, 0) print(out.shape) Outputs: @@ -861,12 +860,7 @@ def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: return result -expand_dims = add_axis - - -def remove_axis( - inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None -) -> Tensor: +def squeeze(inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None) -> Tensor: r""" Removes dimension of shape 1. @@ -883,7 +877,7 @@ def remove_axis( import megengine.functional as F x = tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1)) - out = F.remove_axis(x, 3) + out = F.squeeze(x, 3) print(out.shape) Outputs: @@ -896,9 +890,6 @@ def remove_axis( return _remove_axis(inp, axis) -squeeze = remove_axis - - def linspace( start: Union[int, float, Tensor], stop: Union[int, float, Tensor], @@ -925,7 +916,7 @@ def linspace( print(a.numpy()) Outputs: - + .. testoutput:: [ 3. 4.75 6.5 8.25 10. ] @@ -967,7 +958,7 @@ def arange( a = F.arange(5) print(a.numpy()) - + Outputs: Outputs: @@ -1018,9 +1009,9 @@ def param_pack_split(inp: Tensor, offsets: List, shapes: List) -> Tensor: b, c = F.param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)]) print(b.numpy()) print(c.numpy()) - + Outputs: - + .. testoutput:: [1] @@ -1059,9 +1050,9 @@ def param_pack_concat(inps: List, offsets: Tensor, offsets_val: List) -> Tensor: offsets = tensor(offsets_val, np.int32) c = F.param_pack_concat([a, b], offsets, offsets_val) print(c.numpy()) - + Outputs: - + .. testoutput:: [1 1 1 1 1 1 1 1 1 1] diff --git a/imperative/python/test/unit/core/test_autodiff.py b/imperative/python/test/unit/core/test_autodiff.py index 1cc7d4539..7e017e78b 100644 --- a/imperative/python/test/unit/core/test_autodiff.py +++ b/imperative/python/test/unit/core/test_autodiff.py @@ -306,7 +306,7 @@ def test_AxisAddRemove(): x = TensorWrapper(x_np) grad = Grad().wrt(x, callback=save_to(x)) - y = F.remove_axis(F.add_axis(x, 2), 0) + y = F.squeeze(F.expand_dims(x, 2), 0) grad(y, F.ones_like(y)) np.testing.assert_equal( diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index 88b1bcdaf..b2ebafced 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -100,7 +100,7 @@ def test_squeeze(): for axis in [None, 3, -4, (3, -4)]: y = np.squeeze(x, axis) - yy = F.remove_axis(xx, axis) + yy = F.squeeze(xx, axis) np.testing.assert_equal(y, yy.numpy()) @@ -110,7 +110,7 @@ def test_expand_dims(): for axis in [2, -3, (3, -4), (1, -4)]: y = np.expand_dims(x, axis) - yy = F.add_axis(xx, axis) + yy = F.expand_dims(xx, axis) np.testing.assert_equal(y, yy.numpy()) -- GitLab