未验证 提交 cd54cfab 编写于 作者: W WangZhen 提交者: GitHub

Fix is_tensor_array in getitem (#50502)

上级 d84b918b
...@@ -341,6 +341,18 @@ def for_zip(x, y): ...@@ -341,6 +341,18 @@ def for_zip(x, y):
return x + y return x + y
@paddle.jit.to_static
def tensor_array_slice_in_enumerate():
feats = {}
feats['key'] = []
feats_idx = paddle.arange(0, 10)
for i, idx in enumerate(feats_idx):
if i > 1:
feat_n2 = feats['key'][-2]
feats['key'].append(idx)
return feat_n2
class TestTransformBase(unittest.TestCase): class TestTransformBase(unittest.TestCase):
def setUp(self): def setUp(self):
self.place = ( self.place = (
...@@ -536,6 +548,14 @@ class TestForOriginalTuple(TestTransformForOriginalList): ...@@ -536,6 +548,14 @@ class TestForOriginalTuple(TestTransformForOriginalList):
self.transformed_result_compare() self.transformed_result_compare()
class TestSliceTensorArrayInEnumerate(TestTransformForOriginalList):
def set_test_func(self):
self.dygraph_func = tensor_array_slice_in_enumerate
def test_transformed_result_compare(self):
self.transformed_result_compare()
class TestForZip(unittest.TestCase): class TestForZip(unittest.TestCase):
def setUp(self): def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory() self.temp_dir = tempfile.TemporaryDirectory()
......
...@@ -389,11 +389,11 @@ def _getitem_impl_(var, item): ...@@ -389,11 +389,11 @@ def _getitem_impl_(var, item):
slice_item slice_item
): ):
if ( if (
isinstance(slice_item, int) not is_tensor_array
and isinstance(slice_item, int)
and var.shape[dim] is not None and var.shape[dim] is not None
and var.shape[dim] >= 0 and var.shape[dim] >= 0
and slice_item >= var.shape[dim] and slice_item >= var.shape[dim]
and not is_tensor_array
): ):
# For python, if users write a, b = var, the __getitem__ # For python, if users write a, b = var, the __getitem__
# method will iterate through 0, 1, 2 ... until __getitem__ # method will iterate through 0, 1, 2 ... until __getitem__
...@@ -423,10 +423,10 @@ def _getitem_impl_(var, item): ...@@ -423,10 +423,10 @@ def _getitem_impl_(var, item):
if start is None: if start is None:
start = 0 if step > 0 else MAX_INTEGER start = 0 if step > 0 else MAX_INTEGER
if end is None: if end is None:
if var.shape[dim] != -1 and ( if (
paddle.fluid.framework._non_static_mode() paddle.fluid.framework._non_static_mode()
or not is_tensor_array or not is_tensor_array
): ) and var.shape[dim] != -1:
end = var.shape[dim] if step > 0 else -1 end = var.shape[dim] if step > 0 else -1
else: else:
end = MAX_INTEGER if step > 0 else -1 end = MAX_INTEGER if step > 0 else -1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册