From e788c7b5939ed822437fd0429218604ca1a60bf4 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Mon, 22 Nov 2021 16:58:02 +0800 Subject: [PATCH] Support zero value in dimension for slice (#37313) * support zero dim for slice op * support zero dim Tensor in set_value op * polish some debug log --- paddle/fluid/operators/set_value_op.h | 3 ++ paddle/fluid/operators/slice_utils.h | 33 +++++++++++++------ paddle/fluid/operators/strided_slice_op.h | 7 ++-- .../tests/unittests/test_set_value_op.py | 8 +++++ .../fluid/tests/unittests/test_var_base.py | 6 +++- 5 files changed, 42 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/operators/set_value_op.h b/paddle/fluid/operators/set_value_op.h index 72b94dfa772..71eb0389540 100644 --- a/paddle/fluid/operators/set_value_op.h +++ b/paddle/fluid/operators/set_value_op.h @@ -260,6 +260,9 @@ class SetValueKernel : public framework::OpKernel { starts_indices[axis_index] = starts[i]; ends_indices[axis_index] = ends[i]; strides_indices[axis_index] = steps[i]; + if (starts[i] == ends[i]) { // slice is empty, data will not be changed + return; + } } out_e.stridedSlice(starts_indices, ends_indices, strides_indices) diff --git a/paddle/fluid/operators/slice_utils.h b/paddle/fluid/operators/slice_utils.h index 290df94774b..fa36ded24f9 100644 --- a/paddle/fluid/operators/slice_utils.h +++ b/paddle/fluid/operators/slice_utils.h @@ -30,12 +30,20 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims, std::vector* infer_flags = nullptr) { for (size_t i = 0; i < axes.size(); ++i) { T axis = axes[i]; + PADDLE_ENFORCE_LT( + axis, in_dims.size(), + platform::errors::InvalidArgument( + "The axis value should be less than the rank of input, " + "but received axes[%d] = %d, rank of input is %d.", + i, axis, in_dims.size())); + + if (infer_flags != nullptr && (*infer_flags)[i] == -1) { + continue; + } + T dim_value = in_dims[axis]; if (dim_value > 0) { - if (infer_flags != nullptr && (*infer_flags)[i] == -1) { - continue; - } T step = steps == nullptr ? 1 : (*steps)[i]; PADDLE_ENFORCE_NE( step, 0, platform::errors::InvalidArgument( @@ -51,7 +59,7 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims, if (step > 0) { start = std::min(start, dim_value); end = std::max(end, static_cast(0)); - PADDLE_ENFORCE_GT( + PADDLE_ENFORCE_GE( end, start, platform::errors::InvalidArgument( "When step > 0, end should be greater than start, but " @@ -63,7 +71,7 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims, // "end is -1" means contain the 0-th element of this axis. start = std::min(start, dim_value - 1); end = std::max(end, static_cast(-1)); - PADDLE_ENFORCE_GT( + PADDLE_ENFORCE_GE( start, end, platform::errors::InvalidArgument( "When step < 0, start should be greater than end, but " @@ -73,6 +81,9 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims, (*starts)[i] = start; (*ends)[i] = end; + } else if (dim_value == 0) { + (*starts)[i] = 0; + (*ends)[i] = 0; } } } @@ -111,20 +122,22 @@ inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims, const std::vector& decrease_axes, std::vector* infer_flags = nullptr) { framework::DDim decreased_dims(slice_dims); + std::vector decrease_flag(slice_dims.size(), 0); if (decrease_axes.size() > 0) { for (size_t i = 0; i < decrease_axes.size(); ++i) { T axis = decrease_axes[i]; + decrease_flag[axis] = 1; if (infer_flags && (*infer_flags)[i] != -1) { - PADDLE_ENFORCE_EQ( - decreased_dims[axis], 1, - platform::errors::InvalidArgument("decrease dim should be 1")); + PADDLE_ENFORCE_EQ(decreased_dims[axis], 1, + platform::errors::InvalidArgument( + "Decrease dim should be 1, but now received %d", + decreased_dims[axis])); } - decreased_dims[axis] = 0; } std::vector new_shape; for (int i = 0; i < decreased_dims.size(); ++i) { - if (decreased_dims[i] != 0) { + if (decrease_flag[i] == 0) { new_shape.push_back(decreased_dims[i]); } } diff --git a/paddle/fluid/operators/strided_slice_op.h b/paddle/fluid/operators/strided_slice_op.h index e5b808174ac..9eae27cca68 100644 --- a/paddle/fluid/operators/strided_slice_op.h +++ b/paddle/fluid/operators/strided_slice_op.h @@ -77,10 +77,9 @@ static void StridedSliceOutDims( end_index = end_index + 1; } - bool zero_dim_condition = - ((stride_index < 0 && (start_index <= end_index)) || - (stride_index > 0 && (start_index >= end_index))); - PADDLE_ENFORCE_EQ(zero_dim_condition, false, + bool neg_dim_condition = ((stride_index < 0 && (start_index < end_index)) || + (stride_index > 0 && (start_index > end_index))); + PADDLE_ENFORCE_EQ(neg_dim_condition, false, platform::errors::InvalidArgument( "The start index and end index are invalid for their " "corresponding stride.")); diff --git a/python/paddle/fluid/tests/unittests/test_set_value_op.py b/python/paddle/fluid/tests/unittests/test_set_value_op.py index 057d1b590a0..76cdaff5949 100644 --- a/python/paddle/fluid/tests/unittests/test_set_value_op.py +++ b/python/paddle/fluid/tests/unittests/test_set_value_op.py @@ -127,6 +127,14 @@ class TestSetValueItemSlice4(TestSetValueApi): self.data[0:, 1:2, :] = self.value +class TestSetValueItemSlice5(TestSetValueApi): + def _call_setitem(self, x): + x[0:, 1:1, :] = self.value + + def _get_answer(self): + self.data[0:, 1:1, :] = self.value + + class TestSetValueItemSliceInWhile(TestSetValueApi): def _call_setitem(self, x): def cond(i, x): diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index 3e7b14aa99a..0e50a20a04e 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -568,10 +568,12 @@ class TestVarBase(unittest.TestCase): var14 = var[1:-1, 0:2, ::-1] var15 = var[::-1, ::-1, ::-1] var16 = var[-4:4] + var17 = var[:, 0, 0:0] + var18 = var[:, 1:1:2] vars = [ var, var1, var2, var3, var4, var5, var6, var7, var8, var9, var10, - var11, var12, var13, var14, var15, var16 + var11, var12, var13, var14, var15, var16, var17, var18 ] local_out = [var.numpy() for var in vars] @@ -600,6 +602,8 @@ class TestVarBase(unittest.TestCase): self.assertTrue( np.array_equal(local_out[15], tensor_array[::-1, ::-1, ::-1])) self.assertTrue(np.array_equal(local_out[16], tensor_array[-4:4])) + self.assertTrue(np.array_equal(local_out[17], tensor_array[:, 0, 0:0])) + self.assertTrue(np.array_equal(local_out[18], tensor_array[:, 1:1:2])) def _test_slice_for_tensor_attr(self): tensor_array = np.array( -- GitLab