From 5c7d48cdb9c473c80563a95c53099caabd219687 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 4 Jan 2021 18:36:27 +0800 Subject: [PATCH] fix(mge/functional): fix tensor split GitOrigin-RevId: 0a112ab0bdaa82202c50f7f7b9fe05248b22e415 --- .../python/megengine/functional/elemwise.py | 2 +- .../python/megengine/functional/tensor.py | 87 ++++++++++++------- .../test/unit/functional/test_tensor.py | 26 +++++- 3 files changed, 78 insertions(+), 37 deletions(-) diff --git a/imperative/python/megengine/functional/elemwise.py b/imperative/python/megengine/functional/elemwise.py index 3148a8fa7..878c427a9 100644 --- a/imperative/python/megengine/functional/elemwise.py +++ b/imperative/python/megengine/functional/elemwise.py @@ -158,7 +158,7 @@ def div(x, y): def floor_div(x, y): """Element-wise `floor(x / y)`.""" - return _elwise(x, y, mode=Elemwise.Mode.FLOOR_DIVIDE) + return _elwise(x, y, mode=Elemwise.Mode.FLOOR_DIV) def neg(x): diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index dac5c7d0e..eec1ba2c6 100644 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -28,7 +28,7 @@ from ..core.tensor.utils import ( ) from ..device import get_default_device from ..tensor import Tensor -from .elemwise import ceil +from .elemwise import ceil, floor_div __all__ = [ "arange", @@ -324,52 +324,73 @@ def split(inp, nsplits_or_sections, axis=0): .. testcode:: + import os import numpy as np from megengine import tensor import megengine.functional as F - x = tensor(np.random.random((2,3,4,5)), dtype=np.float32) - out = F.split(x, 2, axis=3) - print(out[0].numpy().shape, out[1].numpy().shape) + x = tensor(np.random.random((10, 20)), dtype=np.float32) + y = F.split(x, 3) + z = F.split(x, [6, 17], axis=1) + + if os.environ.get("MEGENGINE_USE_SYMBOLIC_SHAPE"): + print([tuple(i.shape.numpy().tolist()) for i in y]) + print([tuple(i.shape.numpy().tolist()) for i in z]) + else: + print([i.shape for i in y]) + print([i.shape for i in z]) Outputs: .. testoutput:: - (2, 3, 4, 3) (2, 3, 4, 2) + [(4, 20), (3, 20), (3, 20)] + [(10, 6), (10, 11), (10, 3)] """ - sub_tensors = [] - sections = [] - - def swapaxis(inp, src, dst): - if src == dst: - return inp - shape = [i for i in range(inp.ndim)] - shape[src] = dst - shape[dst] = src - return inp.transpose(shape) - - inp = swapaxis(inp, 0, axis) - - if isinstance(nsplits_or_sections, int): - incr_step = ceil(inp.shape[0] / nsplits_or_sections) - nsplits = nsplits_or_sections - while nsplits > 0: - nsplits -= 1 - sections.append(incr_step.astype("int32")) - incr_step += nsplits_or_sections - else: - sections = nsplits_or_sections + ndim = len(inp.shape) + if axis >= ndim: + raise ValueError("Invalid axis {}".format(axis)) - st = 0 - for se in sections: - sub_tensors.append(swapaxis(inp[st:se], axis, 0)) - st = se + Ntotal = inp.shape[axis] - if st < inp.shape[0]: - sub_tensors.append(swapaxis(inp[st:], axis, 0)) + try: + Nsections = len(nsplits_or_sections) + 1 + is_array = True + except TypeError: + Nsections = int(nsplits_or_sections) + is_array = False + + if is_array: + div_points = [0] + list(nsplits_or_sections) + [Ntotal] + for i in range(1, len(div_points)): + if div_points[i - 1] >= div_points[i]: + raise ValueError( + "Invalid nsplits_or_secions: {}".format(nsplits_or_sections) + ) + else: # scalar + if Nsections <= 0: + raise ValueError("Number sections must be larger than 0") + if Nsections > Ntotal: + raise ValueError( + "The size {} at dim {} cannot be split into {} sections".format( + Ntotal, axis, Nsections + ) + ) + div_points = [0] + [ + floor_div(Ntotal + Nsections - i - 1, Nsections) for i in range(Nsections) + ] + for i in range(2, Nsections + 1): + div_points[i] = div_points[i - 1] + div_points[i] + sub_tensors = [] + for i in range(Nsections): + l = div_points[i] + r = div_points[i + 1] + slices = tuple( + [slice(None)] * axis + [slice(l, r)] + [slice(None)] * (ndim - axis - 1) + ) + sub_tensors.append(inp[slices]) return sub_tensors diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index 4ab424eb5..235f19754 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -77,14 +77,34 @@ def test_stack(): def test_split(): data = np.random.random((2, 3, 4, 5)).astype(np.float32) - mge_out1 = F.split(tensor(data), 2, axis=3) - mge_out2 = F.split(tensor(data), [3, 5], axis=3) + inp = tensor(data) + + mge_out0 = F.split(inp, 2, axis=3) + mge_out1 = F.split(inp, [3], axis=3) np_out = np.split(data, [3, 5], axis=3) - np.testing.assert_equal(mge_out1[0].numpy(), mge_out2[0].numpy()) + assert len(mge_out0) == 2 + assert len(mge_out1) == 2 + + np.testing.assert_equal(mge_out0[0].numpy(), np_out[0]) np.testing.assert_equal(mge_out1[0].numpy(), np_out[0]) + np.testing.assert_equal(mge_out0[1].numpy(), np_out[1]) + np.testing.assert_equal(mge_out1[1].numpy(), np_out[1]) + + try: + F.split(inp, 4) + assert False + except ValueError as e: + pass + + try: + F.split(inp, [3, 3, 5], axis=3) + assert False + except ValueError as e: + assert str(e) == "Invalid nsplits_or_secions: [3, 3, 5]" + def test_reshape(): x = np.arange(6, dtype="float32") -- GitLab