From fccf664f38cd52183ab1b2b2213c5248599cb323 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Tue, 8 Nov 2022 14:47:47 +0800 Subject: [PATCH] [BugFix] fix tensor_array slice bugs in _getitem_impl_ (#46447) * fix tensor_array slice bugs in _getitem_impl_ * fix when var is a paddle.Tensor * code format --- .../tests/unittests/dygraph_to_static/test_list.py | 14 ++++++++++++++ python/paddle/fluid/variable_index.py | 8 ++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py index 08259d91cc..4f6984c5e0 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + import unittest import paddle @@ -124,6 +125,14 @@ def test_list_append_in_while_loop_with_stack(x, iter_num): return out +def test_tensor_array_slice(x, iter_num): + a = [] + for i in range(paddle.to_tensor(3)): + a.append(paddle.to_tensor(i)) + t = a[1:3] + return a[2] + + # Situation 2: Test list pop def test_list_pop_without_control_flow_1(x): x = fluid.dygraph.to_variable(x) @@ -292,6 +301,11 @@ class TestListInWhileLoopWithStack(TestListInWhileLoop): self.all_dygraph_funcs = [test_list_append_in_while_loop_with_stack] +class TestTensorArraySlice(TestListInWhileLoop): + def init_dygraph_func(self): + self.all_dygraph_funcs = [test_tensor_array_slice] + + class TestListInForLoop(TestListInWhileLoop): def init_dygraph_func(self): self.all_dygraph_funcs = [ diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index f3c7fff38f..552fb7a9aa 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -380,6 +380,10 @@ def _getitem_impl_(var, item): item = replace_ellipsis(var, item) item, none_axes = replace_none(item) slice_info = SliceInfo() + is_tensor_array = ( + hasattr(var, "desc") + and var.desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY + ) for dim, slice_item in enumerate(item): if is_integer_or_scalar_tensor(slice_item) and not is_bool_tensor( @@ -390,13 +394,13 @@ def _getitem_impl_(var, item): and var.shape[dim] is not None and var.shape[dim] >= 0 and slice_item >= var.shape[dim] + and not is_tensor_array ): # For python, if users write a, b = var, the __getitem__ # method will iterate through 0, 1, 2 ... until __getitem__ # throws an IndexError, then stop. The var[0], var[1] will # be given to a, b respectively. If more values are given, # the unpack size would cause error. - # # We raises IndexError here to support grammar like `a, b = var` raise IndexError( "slice_item %d at dim %d should be >= 0 and < var.shape[%d]: %d" @@ -422,7 +426,7 @@ def _getitem_impl_(var, item): if end is None: if var.shape[dim] != -1 and ( paddle.fluid.framework._non_static_mode() - or var.desc.type() != core.VarDesc.VarType.LOD_TENSOR_ARRAY + or not is_tensor_array ): end = var.shape[dim] if step > 0 else -1 else: -- GitLab