From fa4bf16800d58a8f29c85aeb41873a1d4b4e883a Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 30 Sep 2020 22:54:17 +0800
Subject: [PATCH] feat(mge/functional): add repeat and tile opr

GitOrigin-RevId: a20d4b6fb0684699175916385e78e3a49776efee
---
 .../python/megengine/autodiff/grad_manager.py |   4 +-
 .../python/megengine/core/tensor/indexing.py  |   2 +-
 .../python/megengine/core/tensor/utils.py     |   8 +-
 .../python/megengine/distributed/helper.py    |   2 +-
 .../python/megengine/functional/inplace.py    |   4 +-
 .../python/megengine/functional/tensor.py     | 143 ++++++++++++++++++
 imperative/python/megengine/tensor.py         |   4 -
 imperative/python/src/tensor.cpp              |  10 +-
 imperative/python/src/tensor.h                |   1 +
 .../test/unit/functional/test_tensor.py       |  50 ++++++
 imperative/python/test/unit/test_tracing.py   |   3 +-
 11 files changed, 214 insertions(+), 17 deletions(-)

diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py
index 3c619512d..70cbc5b70 100644
--- a/imperative/python/megengine/autodiff/grad_manager.py
+++ b/imperative/python/megengine/autodiff/grad_manager.py
@@ -279,8 +279,8 @@ class GradManager:
                     tensor.grad = grad
                 else:
                     tensor.grad += grad
-                if tensor.isscalar() and tensor.grad is not None:
-                    tensor.grad.setscalar()
+                if tensor._isscalar() and tensor.grad is not None:
+                    tensor.grad._setscalar()
         finally:
             self.release()
             backwarding_grad_manager = cache
diff --git a/imperative/python/megengine/core/tensor/indexing.py b/imperative/python/megengine/core/tensor/indexing.py
index c172e65f8..e912d4718 100644
--- a/imperative/python/megengine/core/tensor/indexing.py
+++ b/imperative/python/megengine/core/tensor/indexing.py
@@ -225,7 +225,7 @@ def getitem(tensor, index):
     op = builtin.IndexingMultiAxisVec(items=items)
     (result,) = apply(op, tensor, *tensors)
     if ret_scalar:
-        result.setscalar()
+        result._setscalar()
     return result


diff --git a/imperative/python/megengine/core/tensor/utils.py b/imperative/python/megengine/core/tensor/utils.py
index 2c2a309d3..784eceea5 100644
--- a/imperative/python/megengine/core/tensor/utils.py
+++ b/imperative/python/megengine/core/tensor/utils.py
@@ -51,10 +51,10 @@ def concatenate(inputs, axis=0, *, device=None):
 def astype(x, dtype):
     dtype = np.dtype(dtype)
     if not is_dtype_equal(x.dtype, dtype):
-        isscalar = x.isscalar()
+        isscalar = x._isscalar()
         (x,) = apply(builtin.TypeCvt(dtype=dtype), x)
         if isscalar:
-            x.setscalar()
+            x._setscalar()
     return x


@@ -98,14 +98,14 @@ def result_type(*args):

 def isscalar(x):
     if isinstance(x, Tensor):
-        return x.isscalar()
+        return x._isscalar()

     return np.isscalar(x)


 def setscalar(x):
     if isinstance(x, Tensor):
-        x.setscalar()
+        x._setscalar()
     else:
         raise NotImplementedError("Unsupport type {}".format(type(x)))

diff --git a/imperative/python/megengine/distributed/helper.py b/imperative/python/megengine/distributed/helper.py
index 0b65d411a..bc365c2eb 100644
--- a/imperative/python/megengine/distributed/helper.py
+++ b/imperative/python/megengine/distributed/helper.py
@@ -67,7 +67,7 @@ def param_pack_split(inp: Tensor, offsets: list, shapes: list):
     outputs = apply(op, inp)
     for s, x in zip(shapes, outputs):
         if not s:
-            x.setscalar()
+            x._setscalar()
     return outputs


diff --git a/imperative/python/megengine/functional/inplace.py b/imperative/python/megengine/functional/inplace.py
index 30b96f757..05126f504 100644
--- a/imperative/python/megengine/functional/inplace.py
+++ b/imperative/python/megengine/functional/inplace.py
@@ -12,8 +12,8 @@ from ..core.ops.builtin import InplaceAdd


 def _inplace_add_(dest, delta, alpha, beta):
-    isscalar = dest.isscalar()
+    isscalar = dest._isscalar()
     dest._reset(apply(InplaceAdd(), dest, delta, alpha, beta)[0])
     if isscalar:
-        dest.setscalar()
+        dest._setscalar()
     return dest
diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py
index 4218e52dd..6f7c63e23 100644
--- a/imperative/python/megengine/functional/tensor.py
+++ b/imperative/python/megengine/functional/tensor.py
@@ -44,11 +44,13 @@ __all__ = [
     "linspace",
     "ones",
     "ones_like",
+    "repeat",
     "reshape",
     "split",
     "squeeze",
     "stack",
     "scatter",
+    "tile",
     "transpose",
     "where",
     "zeros",
@@ -987,3 +989,144 @@ def arange(
     if np.dtype(dtype) == np.int32:
         return result.astype(dtype)
     return result
+
+
+def repeat(inp: Tensor, repeats: int, axis: Optional[int] = None):
+    """
+    Repeat elements of an array.
+
+    :param inp: input tensor.
+    :param repeats: the number of repetitions for each element.
+    :param axis: the axis along which to repeat values. By default, use the
+        flattened input array, and return a flat output array.
+    :return: output tensor.
+
+    Examples:
+
+    .. testcode::
+
+        import numpy as np
+        import megengine.functional as F
+        from megengine import tensor
+
+        x = tensor([[1, 2], [3, 4]], np.int32)
+        y = F.repeat(x, 2, axis=0)
+        print(y.numpy())
+
+    Outputs:
+
+    .. testoutput::
+
+        [[1 2]
+         [1 2]
+         [3 4]
+         [3 4]]
+
+    """
+    if axis is None:
+        inp = inp.reshape(-1)  # flatten
+        axis = 0
+        if inp._isscalar():
+            inp._unsetscalar()
+    shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
+    # assume inp.ndim is not changed during trace
+    max_axis = len(shape) - 1
+    assert axis >= 0 and axis <= max_axis
+    assert repeats >= 1
+
+    base_shape, bcast_shape, target_shape = [], [], []
+    if axis != 0:
+        target_shape.append(shape[:axis])
+    base_shape.extend([shape[: axis + 1], [1,]])
+    bcast_shape.extend([shape[: axis + 1], [repeats,]])
+    target_shape.extend(
+        [shape[axis] * repeats,]
+    )
+    if axis + 1 <= max_axis:
+        base_shape.append(shape[axis + 1 :])
+        bcast_shape.append(shape[axis + 1 :])
+        target_shape.append(shape[axis + 1 :])
+
+    out = broadcast_to(inp.reshape(concat(base_shape)), concat(bcast_shape)).reshape(
+        concat(target_shape)
+    )
+    return out
+
+
+def _tile_one_dim(inp, rep, axis):
+    shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
+    # assume inp.ndim is not changed during trace
+    max_axis = len(shape) - 1
+
+    base_shape, bcast_shape, target_shape = [], [], []
+
+    if axis != 0:
+        base_shape.append(shape[:axis])
+        bcast_shape.append(shape[:axis])
+        target_shape.append(shape[:axis])
+    base_shape.extend([[1,], shape[axis:]])
+    bcast_shape.extend([rep, shape[axis:]])
+    target_shape.append(shape[axis] * rep)
+    if axis + 1 <= max_axis:
+        target_shape.append(shape[axis + 1 :])
+
+    out = broadcast_to(inp.reshape(concat(base_shape)), concat(bcast_shape)).reshape(
+        concat(target_shape)
+    )
+    return out
+
+
+def tile(inp: Tensor, reps: Iterable[int]):
+    """
+    Construct an array by repeating ``inp`` the number of times given by ``reps``.
+    If ``reps`` has length ``d``, the result will have ``max(d, inp.ndim)``
+    dimensions. It is required that ``d >= inp.ndim``. If ``inp.ndim < d``,
+    ``inp`` is promoted to be ``d``-dimensional by prepending new axes.
+
+    :param inp: input tensor.
+    :param reps: the number of repetitions of ``inp`` along each axis.
+    :return: output tensor.
+
+    Examples:
+
+    .. testcode::
+
+        import numpy as np
+        import megengine.functional as F
+        from megengine import tensor
+
+        x = tensor([[1, 2], [3, 4]], np.int32)
+        y = F.tile(x, (2, 1))
+        print(y.numpy())
+
+    Outputs:
+
+    .. testoutput::
+
+        [[1 2]
+         [3 4]
+         [1 2]
+         [3 4]]
+
+    """
+    shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
+    reps = astensor1d(reps, inp, dtype="int32", device=inp.device)
+    l_shape = len(shape)
+    l_reps = len(reps)
+    assert (
+        l_reps >= l_shape
+    ), "the length of reps must not be smaller than the number of dimensions of inp"
+
+    for i in range(l_shape):
+        rep = reps[i + (l_reps - l_shape)]
+        inp = _tile_one_dim(inp, rep, i)
+
+    if l_reps > l_shape:
+        shape = inp.shape
+        extra = reps[:-l_shape]
+        extra_ones = ones_like(extra)
+        base_shape = concat([extra_ones, shape])
+        bcast_shape = concat([extra, shape])
+        target_shape = concat([extra, shape])
+        inp = broadcast_to(inp.reshape(base_shape), bcast_shape).reshape(target_shape)
+
+    return inp
diff --git a/imperative/python/megengine/tensor.py b/imperative/python/megengine/tensor.py
index 8f6072271..65477d1fd 100644
--- a/imperative/python/megengine/tensor.py
+++ b/imperative/python/megengine/tensor.py
@@ -51,10 +51,6 @@ class Tensor(_Tensor, ArrayMethodMixin):
             cn = device._cn

         if isinstance(data, _Tensor):
-            if dtype is not None:
-                logger.warning(
-                    "dtype does not work when creating a new Tensor with another Tensor"
-                )
             obj = _Tensor.__new__(cls, data)
         else:
             if isinstance(data, np.ndarray):
diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp
index 0064192c6..453003598 100644
--- a/imperative/python/src/tensor.cpp
+++ b/imperative/python/src/tensor.cpp
@@ -557,6 +557,11 @@ void TensorWrapper::setscalar() {
 }


+void TensorWrapper::unsetscalar() {
+    m_tensor->m_flags &= ~Tensor::Flags::SCALAR;
+}
+
+
 struct TensorWeakRef {
     std::weak_ptr<Tensor> wptr;

@@ -794,8 +799,9 @@ void init_tensor(py::module m) {
         .def_getset<&TensorWrapper::dtype>("dtype")
         .def_getset<&TensorWrapper::device>("device")
         .def<&TensorWrapper::reset>("_reset")
-        .def<&TensorWrapper::isscalar>("isscalar")
-        .def<&TensorWrapper::setscalar>("setscalar")
+        .def<&TensorWrapper::isscalar>("_isscalar")
+        .def<&TensorWrapper::setscalar>("_setscalar")
+        .def<&TensorWrapper::unsetscalar>("_unsetscalar")
         .def<&TensorWrapper::detach>("detach")
         .def<&TensorWrapper::_dev_tensor>("_dev_tensor")
         .def<&TensorWrapper::_swap_out>("_swap_out")
diff --git a/imperative/python/src/tensor.h b/imperative/python/src/tensor.h
index 57c95069c..d3d5fef48 100644
--- a/imperative/python/src/tensor.h
+++ b/imperative/python/src/tensor.h
@@ -153,6 +153,7 @@ struct TensorWrapper {
     PyObject* detach();
     PyObject* isscalar();
     void setscalar();
+    void unsetscalar();
     PyObject* _dev_tensor();
     void _swap_in();
     void _swap_out();
diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py
index da6825ff3..0ffa5ebad 100644
--- a/imperative/python/test/unit/functional/test_tensor.py
+++ b/imperative/python/test/unit/functional/test_tensor.py
@@ -406,3 +406,53 @@ def test_copy_d2h():
 def test_copy_d2d():
     copy_test("gpu0", "gpu1")
     copy_test("gpu0:0", "gpu0:1")
+
+
+@pytest.mark.parametrize(
+    "shape, repeats, axis",
+    [
+        ((2,), 2, 0),
+        ((2, 3, 4, 5), 3, 0),
+        ((2, 3, 4, 5), 4, 3),
+        ((2,), 2, None),
+        ((2, 3, 4, 5), 3, None),
+        ((), 1, None),
+        ((), 10, None),
+    ],
+)
+def test_repeat(shape, repeats, axis):
+    def repeat_func(inp):
+        return F.repeat(inp=inp, repeats=repeats, axis=axis)
+
+    if shape != ():
+        cases = [
+            {"input": np.random.randn(*shape).astype("float32")},
+        ]
+    else:
+        cases = [{"input": np.array(1.23)}]
+
+    opr_test(
+        cases, repeat_func, ref_fn=lambda inp: np.repeat(inp, repeats, axis),
+    )
+
+
+@pytest.mark.parametrize(
+    "shape, reps",
+    [
+        ((2,), (2,)),
+        ((2, 3, 4, 5), (1, 1, 1, 1)),
+        ((2, 3, 4, 5), (1, 2, 3, 4)),
+        ((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)),
+    ],
+)
+def test_tile(shape, reps):
+    def tile_func(inp):
+        return F.tile(inp=inp, reps=reps)
+
+    cases = [
+        {"input": np.random.randn(*shape).astype("float32")},
+    ]
+
+    opr_test(
+        cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps),
+    )
diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py
index d838149d8..3742d6100 100644
--- a/imperative/python/test/unit/test_tracing.py
+++ b/imperative/python/test/unit/test_tracing.py
@@ -7,6 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import io
+import itertools
 from tempfile import mkstemp

 import numpy as np
@@ -359,7 +360,7 @@ def test_trace_warp_perspective():
         np.testing.assert_equal(out.shape.numpy(), np.array([1, 1, 2, 2]))
         return out

-    for i in range(1):
+    for i in range(3):
         f(x, M)


--
GitLab
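Reviewer note, appended after the GitLab signature line so `git am` still
applies the patch unchanged: the new F.repeat/F.tile above deliberately avoid
a dedicated repeat kernel and instead lower the op to a reshape ->
broadcast_to -> reshape chain over shape tensors. The standalone NumPy sketch
below reproduces that trick for the repeat case; `repeat_via_broadcast` is an
illustrative name, not an API added by this patch.

# A minimal sketch, assuming a non-negative `axis` and `repeats >= 1`, of the
# reshape -> broadcast -> reshape lowering used by F.repeat above.
import numpy as np

def repeat_via_broadcast(x, repeats, axis):
    shape = x.shape
    # (..., d, ...) -> (..., d, 1, ...): insert a length-1 axis after `axis`.
    base = shape[: axis + 1] + (1,) + shape[axis + 1 :]
    # (..., d, 1, ...) -> (..., d, repeats, ...): broadcast the new axis.
    bcast = shape[: axis + 1] + (repeats,) + shape[axis + 1 :]
    # (..., d, repeats, ...) -> (..., d * repeats, ...): fold it back in.
    target = shape[:axis] + (shape[axis] * repeats,) + shape[axis + 1 :]
    return np.broadcast_to(x.reshape(base), bcast).reshape(target)

x = np.array([[1, 2], [3, 4]], dtype=np.int32)
assert np.array_equal(repeat_via_broadcast(x, 2, 0), np.repeat(x, 2, axis=0))
# _tile_one_dim applies the same expansion, but inserts the length-1 axis
# *before* `axis`, so whole slices repeat as blocks instead of per element.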