From cd54cfab7f785fa2912dc63d8acd21e2df050ad7 Mon Sep 17 00:00:00 2001 From: WangZhen <23097963+0x45f@users.noreply.github.com> Date: Wed, 15 Feb 2023 10:45:50 +0800 Subject: [PATCH] Fix is_tensor_array in getitem (#50502) --- .../dygraph_to_static/test_for_enumerate.py | 20 +++++++++++++++++++ python/paddle/fluid/variable_index.py | 8 ++++---- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py index 42704cfe289..e7dcfb68491 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py @@ -341,6 +341,18 @@ def for_zip(x, y): return x + y +@paddle.jit.to_static +def tensor_array_slice_in_enumerate(): + feats = {} + feats['key'] = [] + feats_idx = paddle.arange(0, 10) + for i, idx in enumerate(feats_idx): + if i > 1: + feat_n2 = feats['key'][-2] + feats['key'].append(idx) + return feat_n2 + + class TestTransformBase(unittest.TestCase): def setUp(self): self.place = ( @@ -536,6 +548,14 @@ class TestForOriginalTuple(TestTransformForOriginalList): self.transformed_result_compare() +class TestSliceTensorArrayInEnumerate(TestTransformForOriginalList): + def set_test_func(self): + self.dygraph_func = tensor_array_slice_in_enumerate + + def test_transformed_result_compare(self): + self.transformed_result_compare() + + class TestForZip(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index f1c1e7dfb41..93f6a21b7b7 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -389,11 +389,11 @@ def _getitem_impl_(var, item): slice_item ): if ( - isinstance(slice_item, int) + not is_tensor_array + and isinstance(slice_item, int) and var.shape[dim] is not None and var.shape[dim] >= 0 and slice_item >= var.shape[dim] - and not is_tensor_array ): # For python, if users write a, b = var, the __getitem__ # method will iterate through 0, 1, 2 ... until __getitem__ @@ -423,10 +423,10 @@ def _getitem_impl_(var, item): if start is None: start = 0 if step > 0 else MAX_INTEGER if end is None: - if var.shape[dim] != -1 and ( + if ( paddle.fluid.framework._non_static_mode() or not is_tensor_array - ): + ) and var.shape[dim] != -1: end = var.shape[dim] if step > 0 else -1 else: end = MAX_INTEGER if step > 0 else -1 -- GitLab