未验证 提交 fccf664f 编写于 作者: X xiongkun 提交者: GitHub

[BugFix] fix tensor_array slice bugs in _getitem_impl_ (#46447)

* fix tensor_array slice bugs in _getitem_impl_

* fix when var is a paddle.Tensor

* code format
上级 97004f67
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
......@@ -124,6 +125,14 @@ def test_list_append_in_while_loop_with_stack(x, iter_num):
return out
def test_tensor_array_slice(x, iter_num):
a = []
for i in range(paddle.to_tensor(3)):
a.append(paddle.to_tensor(i))
t = a[1:3]
return a[2]
# Situation 2: Test list pop
def test_list_pop_without_control_flow_1(x):
x = fluid.dygraph.to_variable(x)
......@@ -292,6 +301,11 @@ class TestListInWhileLoopWithStack(TestListInWhileLoop):
self.all_dygraph_funcs = [test_list_append_in_while_loop_with_stack]
class TestTensorArraySlice(TestListInWhileLoop):
def init_dygraph_func(self):
self.all_dygraph_funcs = [test_tensor_array_slice]
class TestListInForLoop(TestListInWhileLoop):
def init_dygraph_func(self):
self.all_dygraph_funcs = [
......
......@@ -380,6 +380,10 @@ def _getitem_impl_(var, item):
item = replace_ellipsis(var, item)
item, none_axes = replace_none(item)
slice_info = SliceInfo()
is_tensor_array = (
hasattr(var, "desc")
and var.desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY
)
for dim, slice_item in enumerate(item):
if is_integer_or_scalar_tensor(slice_item) and not is_bool_tensor(
......@@ -390,13 +394,13 @@ def _getitem_impl_(var, item):
and var.shape[dim] is not None
and var.shape[dim] >= 0
and slice_item >= var.shape[dim]
and not is_tensor_array
):
# For python, if users write a, b = var, the __getitem__
# method will iterate through 0, 1, 2 ... until __getitem__
# throws an IndexError, then stop. The var[0], var[1] will
# be given to a, b respectively. If more values are given,
# the unpack size would cause error.
#
# We raises IndexError here to support grammar like `a, b = var`
raise IndexError(
"slice_item %d at dim %d should be >= 0 and < var.shape[%d]: %d"
......@@ -422,7 +426,7 @@ def _getitem_impl_(var, item):
if end is None:
if var.shape[dim] != -1 and (
paddle.fluid.framework._non_static_mode()
or var.desc.type() != core.VarDesc.VarType.LOD_TENSOR_ARRAY
or not is_tensor_array
):
end = var.shape[dim] if step > 0 else -1
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册