From 43fb2c4afba76cd3256eecbbf2a4f357dd893275 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 14 Sep 2021 16:14:26 +0800 Subject: [PATCH] feat(opr): let roll support empty IO GitOrigin-RevId: b9a59b623a8b16ca0a6af1340cfc226b73128321 --- .../python/megengine/functional/tensor.py | 7 +++--- .../test/unit/functional/test_tensor.py | 23 +++++++++++++++++++ src/opr/impl/tensor_manip.cpp | 6 +++-- 3 files changed, 31 insertions(+), 5 deletions(-) diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 67aa9f4a..a881fab1 100755 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -1352,10 +1352,11 @@ def roll( if shift_ == 0: continue size = shp[axis_normalized_] - if shift_ > 0: - a, b = split(out, [size - shift_,], axis=axis_normalized_) + shift_normalized_ = 0 if size == 0 else shift_ % size + if shift_normalized_ > 0: + a, b = split(out, [size - shift_normalized_,], axis=axis_normalized_) else: - a, b = split(out, [-shift_,], axis=axis_normalized_) + a, b = split(out, [-shift_normalized_,], axis=axis_normalized_) out = concat((b, a), axis=axis_normalized_) if shp_bak is not None: out = out.reshape(shp_bak) diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index e3423a49..aa609dcd 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -806,6 +806,8 @@ def test_tile(shape, reps, is_varnode): [ ((2, 3), 0, None), ((2, 3), 1, 0), + ((2, 3), 100, 0), + ((2, 3), -100, 0), ((2, 3, 4, 5), (-1, 1), (0, 1)), ((2, 3, 4, 5), (-2, 1, 2), (1, 2, 3)), ], @@ -829,3 +831,24 @@ def test_roll(shape, shifts, axis, is_varnode): opr_test( cases, func, ref_fn=lambda inp: np.roll(inp, shifts, axis), network=network ) + + +@pytest.mark.parametrize( + "shape, shifts, axis", [((10, 0), 5, 1), ((10, 0), -10, 1),], +) +@pytest.mark.parametrize("is_symbolic", [None, True, False]) +def test_roll_empty_tensor(shape, shifts, axis, is_symbolic): + inp = tensor(np.random.randn(*shape).astype("float32")) + + def func(inp): + return F.roll(inp, shifts, axis) + + if is_symbolic is not None: + func = trace(symbolic=is_symbolic)(func) + + out_ref = np.roll(inp.numpy(), shifts, axis) + for _ in range(3): + out = F.roll(inp, shifts, axis) + np.testing.assert_equal(out.numpy(), out_ref) + if is_symbolic is None: + break diff --git a/src/opr/impl/tensor_manip.cpp b/src/opr/impl/tensor_manip.cpp index 11d56587..6df7b518 100644 --- a/src/opr/impl/tensor_manip.cpp +++ b/src/opr/impl/tensor_manip.cpp @@ -1339,8 +1339,10 @@ void Concat::scn_do_execute() { if (real_axis < 0) real_axis += in.shape().ndim; end = begin + in.shape().shape[real_axis]; - out.sub(Slice(begin, end).apply(out.layout(), real_axis)). - copy_from_fixlayout(in); + if (!in.layout().is_empty()) { + out.sub(Slice(begin, end).apply(out.layout(), real_axis)). + copy_from_fixlayout(in); + } } } -- GitLab