未验证 提交 cd54cfab 编写于 作者: W WangZhen 提交者: GitHub

Fix is_tensor_array in getitem (#50502)

上级 d84b918b
......@@ -341,6 +341,18 @@ def for_zip(x, y):
return x + y
@paddle.jit.to_static
def tensor_array_slice_in_enumerate():
feats = {}
feats['key'] = []
feats_idx = paddle.arange(0, 10)
for i, idx in enumerate(feats_idx):
if i > 1:
feat_n2 = feats['key'][-2]
feats['key'].append(idx)
return feat_n2
class TestTransformBase(unittest.TestCase):
def setUp(self):
self.place = (
......@@ -536,6 +548,14 @@ class TestForOriginalTuple(TestTransformForOriginalList):
self.transformed_result_compare()
class TestSliceTensorArrayInEnumerate(TestTransformForOriginalList):
def set_test_func(self):
self.dygraph_func = tensor_array_slice_in_enumerate
def test_transformed_result_compare(self):
self.transformed_result_compare()
class TestForZip(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
......
......@@ -389,11 +389,11 @@ def _getitem_impl_(var, item):
slice_item
):
if (
isinstance(slice_item, int)
not is_tensor_array
and isinstance(slice_item, int)
and var.shape[dim] is not None
and var.shape[dim] >= 0
and slice_item >= var.shape[dim]
and not is_tensor_array
):
# For python, if users write a, b = var, the __getitem__
# method will iterate through 0, 1, 2 ... until __getitem__
......@@ -423,10 +423,10 @@ def _getitem_impl_(var, item):
if start is None:
start = 0 if step > 0 else MAX_INTEGER
if end is None:
if var.shape[dim] != -1 and (
if (
paddle.fluid.framework._non_static_mode()
or not is_tensor_array
):
) and var.shape[dim] != -1:
end = var.shape[dim] if step > 0 else -1
else:
end = MAX_INTEGER if step > 0 else -1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册