未验证 提交 aa50868f 编写于 作者: H Huihuang Zheng 提交者: GitHub

[Dy2stat] Add Support for a, b = static_variable Grammar (#33499)

For python, if users write `a, b = var`, the `__getitem__` method will iterate through 0, 1, 2 ... until `__getitem__` throws an IndexError, then stop. The var[0], var[1] will be given to a, b respectively. If more values are given, the unpack size would cause error.

We didn't raise the IndexError in the past and we add statement in `__getitem__` to raise IndexError here to support grammar like `a, b = var` in this PR.
上级 71f8707b
......@@ -85,6 +85,13 @@ def dyfunc_tuple_shape_2(x):
return res
def dyfunc_tuple_shape_3(x):
x = paddle.to_tensor(x)
a, b = paddle.shape(x)
res = paddle.reshape(x, shape=(b, a))
return res
def dyfunc_paddle_shape_api(x):
x = paddle.to_tensor(x)
# paddle.shape will not be converted.
......@@ -337,6 +344,18 @@ class TestTupleShape2(TestTensorShapeBasic):
self.expected_slice_op_num = 2
class TestTupleShape3(TestTensorShapeBasic):
def init_test_func(self):
self.input = numpy.ones((5, 7)).astype("int32")
self.input_spec = [paddle.static.InputSpec(shape=[5, 7], dtype="int32")]
self.dygraph_func = dyfunc_tuple_shape_3
def _set_expected_op_num(self):
self.expected_op_num = 5
self.expected_shape_op_num = 1
self.expected_slice_op_num = 2
class TestPaddleShapeApi(TestTensorShapeBasic):
def init_test_func(self):
self.input = numpy.ones((5, 7)).astype("int32")
......
......@@ -116,6 +116,19 @@ def _getitem_impl_(var, item):
for dim, slice_item in enumerate(item):
if is_integer_or_scalar_tensor(slice_item):
if isinstance(slice_item,
int) and var.shape[dim] is not None and var.shape[
dim] >= 0 and slice_item >= var.shape[dim]:
# For python, if users write a, b = var, the __getitem__
# method will iterate through 0, 1, 2 ... until __getitem__
# throws an IndexError, then stop. The var[0], var[1] will
# be given to a, b respectively. If more values are given,
# the unpack size would cause error.
#
# We raises IndexError here to support grammar like `a, b = var`
raise IndexError(
"slice_item %d at dim %d should be >= 0 and < var.shape[%d]: %d"
% (slice_item, dim, dim, var.shape[dim]))
decrease_axes.append(dim)
start = slice_item
step = 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册