From aa50868fc11c05b130355c9335eab79e8acb3824 Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Fri, 11 Jun 2021 13:23:56 +0800 Subject: [PATCH] [Dy2stat] Add Support for a, b = static_variable Grammar (#33499) For python, if users write `a, b = var`, the `__getitem__` method will iterate through 0, 1, 2 ... until `__getitem__` throws an IndexError, then stop. The var[0], var[1] will be given to a, b respectively. If more values are given, the unpack size would cause error. We didn't raise the IndexError in the past and we add statement in `__getitem__` to raise IndexError here to support grammar like `a, b = var` in this PR. --- .../dygraph_to_static/test_tensor_shape.py | 19 +++++++++++++++++++ python/paddle/fluid/variable_index.py | 13 +++++++++++++ 2 files changed, 32 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py index ace49db107..f7cdb12a1a 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py @@ -85,6 +85,13 @@ def dyfunc_tuple_shape_2(x): return res +def dyfunc_tuple_shape_3(x): + x = paddle.to_tensor(x) + a, b = paddle.shape(x) + res = paddle.reshape(x, shape=(b, a)) + return res + + def dyfunc_paddle_shape_api(x): x = paddle.to_tensor(x) # paddle.shape will not be converted. @@ -337,6 +344,18 @@ class TestTupleShape2(TestTensorShapeBasic): self.expected_slice_op_num = 2 +class TestTupleShape3(TestTensorShapeBasic): + def init_test_func(self): + self.input = numpy.ones((5, 7)).astype("int32") + self.input_spec = [paddle.static.InputSpec(shape=[5, 7], dtype="int32")] + self.dygraph_func = dyfunc_tuple_shape_3 + + def _set_expected_op_num(self): + self.expected_op_num = 5 + self.expected_shape_op_num = 1 + self.expected_slice_op_num = 2 + + class TestPaddleShapeApi(TestTensorShapeBasic): def init_test_func(self): self.input = numpy.ones((5, 7)).astype("int32") diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index c6ddba7fea..c9363dff13 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -116,6 +116,19 @@ def _getitem_impl_(var, item): for dim, slice_item in enumerate(item): if is_integer_or_scalar_tensor(slice_item): + if isinstance(slice_item, + int) and var.shape[dim] is not None and var.shape[ + dim] >= 0 and slice_item >= var.shape[dim]: + # For python, if users write a, b = var, the __getitem__ + # method will iterate through 0, 1, 2 ... until __getitem__ + # throws an IndexError, then stop. The var[0], var[1] will + # be given to a, b respectively. If more values are given, + # the unpack size would cause error. + # + # We raises IndexError here to support grammar like `a, b = var` + raise IndexError( + "slice_item %d at dim %d should be >= 0 and < var.shape[%d]: %d" + % (slice_item, dim, dim, var.shape[dim])) decrease_axes.append(dim) start = slice_item step = 1 -- GitLab