未验证 提交 697c712f 编写于 作者: J JYChen 提交者: GitHub

Support Combined indexing for __getitem__ and __setitem__ (#55211)

* WIP: start writing combined indexing get

* list/tuple/Variable

* getitem 80%

* add setitem

* add some unittest for setitem

* lazy import

* fix some setitem error

* fix advance indexing with decreasing axes; fix strided_slice input name

* combine int-tensor getitem is ok (without boolean support & broadcast); add getitem unittest for static

* add broadcast & parse bool tensor for __getitem

* [change getitem] _getitem_impl_ to _getitem_static, not deleting the former one

* refine new getitem; fix ut in variable/var_base

* add __getitem__ ut in dygraph

* re-dispatch getitem for Py/CPP; fix strided_slice decrease axes error in dygraph

* fix ut; support tensor in slice

* [change setitem] _setitem_impl_ to _setitem_static, not deleting the former one

* remove some UT (for some, temporarily)

* add IndexError to solve timeout problem in static-mode

* 1.temply forbideen all-False bool-indexput; 2.setitem_static will return new variable

* xpu uses old stratege

* rename dy2st setitem ut to avoid same-name problem

* dy2st for new combined index

* ut case for combine-index with dy2st

* open ut with all-false-bool setitem

* remove useless doc and _getitem_impl_

* change static res

* fix static xpu
上级 81511469
......@@ -1120,6 +1120,9 @@ static PyObject* tensor__getitem_index_not_tensor(TensorObject* self,
eager_gil_scoped_release guard;
out = strided_slice_ad_func(
self->tensor, slice_axes, slice_starts, slice_ends, slice_strides);
if (!decrease_axis_tmp.empty()) {
out = squeeze_ad_func(out, decrease_axis_tmp);
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Slice is only support slice and strided_slice, but we got %s which "
......
......@@ -26,7 +26,8 @@ from .. import unique_name
from ..framework import (
Variable,
Parameter,
_getitem_impl_,
_getitem_static,
_setitem_static,
_setitem_impl_,
EagerParamBase,
in_dygraph_mode,
......@@ -726,47 +727,34 @@ def monkey_patch_tensor():
return True
return False
def __getitem__(self, item):
def is_list_tuple(index, contain_type):
def _is_list_tuple(item):
if isinstance(item, (tuple, list)):
for s in item:
if not _is_list_tuple(s):
return False
else:
if type(item) != contain_type:
return False
def contain_tensor_or_list(item):
if not isinstance(item, tuple):
item = (item,)
for slice_item in item:
if isinstance(slice_item, (list, np.ndarray, Variable)):
return True
elif isinstance(slice_item, slice):
if (
isinstance(slice_item.start, Variable)
or isinstance(slice_item.stop, Variable)
or isinstance(slice_item.step, Variable)
):
return True
if not isinstance(index, (tuple, list)):
return False
for s in index:
if not _is_list_tuple(s):
return False
return True
if contain_tensor(item) or is_list_tuple(item, int):
def __getitem__(self, item):
if contain_tensor_or_list(item):
# 1. Call _getitem_impl_ when item contains tensor.
# Why not call a c++ function ? Because item can't be parsed when it contains tensor.
return _getitem_impl_(self, item)
return _getitem_static(self, item)
else:
# 2. Call c++ func getitem_index_not_tensor to speedup.
return self._getitem_index_not_tensor(item)
def __setitem__(self, item, value):
def contain_tensor_or_list(item):
if not isinstance(item, tuple):
item = [item]
for slice_item in item:
if isinstance(slice_item, list):
return True
elif isinstance(slice_item, Variable):
return True
return False
def is_combine_index(item):
var_type = None
item_type = None
......@@ -788,10 +776,13 @@ def monkey_patch_tensor():
return False
if contain_tensor_or_list(item) and not is_combine_index(item):
# To reuse code with static graph,
# Call _setitem_impl_ when item contains tensor or list.
if contain_tensor_or_list(item):
if core.is_compiled_with_xpu() and not is_combine_index(item):
# (NOTE): Currently, there is no index_put_xpu kernel.
return _setitem_impl_(self, item, value)
# To reuse code with static graph,
# Call _setitem_static when item contains tensor or list.
return _setitem_static(self, item, value)
else:
return self.__setitem_eager_tensor__(item, value)
......
......@@ -37,7 +37,7 @@ from . import unique_name
import paddle.version as fluid_version
import warnings
import functools
from .variable_index import _getitem_impl_, _setitem_impl_
from .variable_index import _getitem_static, _setitem_static, _setitem_impl_
import threading
__all__ = [
......@@ -2294,13 +2294,16 @@ class Variable(metaclass=VariableMetaClass):
raise IndexError("Valid index accept int or slice or tuple")
def __getitem__(self, item):
return _getitem_impl_(self, item)
return _getitem_static(self, item)
def __setitem__(self, item, value):
from .dygraph.base import in_declarative_mode
if in_declarative_mode():
if is_compiled_with_xpu():
# (NOTE): Currently, there is no index_put_xpu kernel.
return _setitem_impl_(self, item, value)
return _setitem_static(self, item, value)
else:
raise RuntimeError(
"In static mode, the __setitem__ (looks like: x[indices] = values) should not be used. Please use x = paddle.static.setitem(x, indices, values)"
......
......@@ -18,7 +18,7 @@ from paddle.fluid.data_feeder import check_type
from paddle.fluid.framework import convert_np_dtype_to_dtype_, static_only
from paddle.fluid.layer_helper import LayerHelper
from ..fluid.variable_index import _setitem_impl_
from ..fluid.variable_index import _setitem_impl_, _setitem_static
__all__ = []
......@@ -367,5 +367,8 @@ def setitem(x, index, value):
(1) a[Tensor([10,10])]=v -> setitem(a, (Tensor([10,10]),), v)
(2) a[1] = v -> setitem(a, (1,), v)
"""
if core.is_compiled_with_xpu():
# (NOTE): Currently, there is no index_put_xpu kernel.
return _setitem_impl_(x, index, value)
else:
return _setitem_static(x, index, value)
......@@ -127,6 +127,7 @@ if(WITH_TESTING)
add_subdirectory(ipu)
endif()
add_subdirectory(ir)
add_subdirectory(indexing)
add_subdirectory(legacy_test)
if(WITH_MKLDNN)
add_subdirectory(mkldnn)
......
......@@ -178,5 +178,24 @@ class TestCase11(TestSetItemBase):
return y, x_grad, value_grad
class TestCase12(TestSetItemBase):
# Test combind-indexing
def init_func(self):
def foo(x, value):
y = x + 1
y[[0, 1], 1, :2] = value
return y
return foo
def run_dygrah(self, func):
x = self.init_data()
value = paddle.ones((32,))
value.stop_gradient = False
y = func(x, value)
x_grad, value_grad = paddle.grad(y, [x, value])
return y, x_grad, value_grad
if __name__ == '__main__':
unittest.main()
file(
GLOB TEST_OPS
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
"test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.fluid.variable_index import _getitem_static
class TestGetitemInDygraph(unittest.TestCase):
def setUp(self):
paddle.disable_static()
def test_combined_index_1(self):
# int tensor + slice (without decreasing axes)
np_data = np.random.randn(3, 4, 5, 6)
np_res = np_data[[0, 1], :, [1, 2]]
x = paddle.to_tensor(np_data)
y = x[[0, 1], :, [1, 2]]
np.testing.assert_allclose(y.numpy(), np_res)
def test_combined_index_2(self):
# int tensor + slice (with decreasing axes)
np_data = np.random.randn(3, 4, 5, 6)
x = paddle.to_tensor(np_data)
np_res = np_data[:, 1, [1, 2], 0]
y = x[:, 1, [1, 2], 0]
np.testing.assert_allclose(y.numpy(), np_res)
def test_combined_index_3(self):
# multiple int tensors, with one int tensor at first axis
np_data = np.random.randn(3, 4, 5, 6, 7)
np_res = np_data[[1, 0], :, [1, 4], 1:5:2, 4]
x = paddle.to_tensor(np_data)
y = x[[1, 0], :, [1, 4], 1:5:2, 4]
np.testing.assert_allclose(y.numpy(), np_res)
def test_combined_index_4(self):
# multiple not adjacent int tensors, with no int tensor at first axis
np_data = np.random.randn(3, 4, 5, 6, 7)
np_res = np_data[:, [1, 0], 0:4:2, [2, 3], 4]
x = paddle.to_tensor(np_data)
y = x[:, [1, 0], 0:4:2, [2, 3], 4]
np.testing.assert_allclose(y.numpy(), np_res)
def test_combined_index_5(self):
# multiple adjacent int tensors, with no int tensor at first axis
np_data = np.random.randn(3, 4, 5, 6, 7)
np_res = np_data[::2, [1, 0], [2, 3], 0:4:2]
x = paddle.to_tensor(np_data)
y = x[::2, [1, 0], [2, 3], 0:4:2]
np.testing.assert_allclose(y.numpy(), np_res)
def test_combined_index_6(self):
# multiple adjacent and not adjacent int tensors, with no int tensor at first axis
np_data = np.random.randn(3, 4, 5, 6, 7)
np_res = np_data[::2, [1, 0], [2, 3], 0:4:2, [4, 6]]
x = paddle.to_tensor(np_data)
y = x[::2, [1, 0], [2, 3], 0:4:2, [4, 6]]
np.testing.assert_allclose(y.numpy(), np_res)
def test_combined_index_7(self):
# multiple adjacent and not adjacent int tensors (rank > 1d), with no int tensor at first axis
np_data = np.random.randn(3, 4, 5, 6, 7)
np_res = np_data[::2, [[1, 0]], [[2, 3]], 0:4:2, [[4, 6]]]
x = paddle.to_tensor(np_data)
y = x[::2, [[1, 0]], [[2, 3]], 0:4:2, [[4, 6]]]
np.testing.assert_allclose(y.numpy(), np_res)
def test_combined_index_8(self):
# multiple adjacent and not adjacent int tensors (rank > 1d), with int tensor at first axis
np_data = np.random.randn(3, 4, 5, 6, 7)
np_res = np_data[
[[1, 0], [0, 1]], [[2, 3], [1, 0]], 0:4:2, [[3, 5], [4, 2]]
]
x = paddle.to_tensor(np_data)
y = x[[[1, 0], [0, 1]], [[2, 3], [1, 0]], 0:4:2, [[3, 5], [4, 2]]]
np.testing.assert_allclose(y.numpy(), np_res)
def test_combined_index_9(self):
# multiple int tensors, with broadcast.
np_data = np.random.randn(3, 4, 5, 6, 7)
np_res = np_data[[[1, 0]], [1, 0], 0:4:2, [[3, 5], [4, 2]]]
x = paddle.to_tensor(np_data)
y = x[[[1, 0]], [1, 0], 0:4:2, [[3, 5], [4, 2]]]
np.testing.assert_allclose(y.numpy(), np_res)
def test_combined_index_10(self):
# only one bool tensor with basic-index
np_data = np.random.randn(3, 4, 5, 6)
np_res = np_data[:, [True, False, True, False], 4]
x = paddle.to_tensor(np_data)
y = x[:, [True, False, True, False], 4]
np.testing.assert_allclose(y.numpy(), np_res)
def test_combined_index_11(self):
# only one bool tensor with all False
np_data = np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6))
np_res = np_data[:, [False, False, False, False], 4]
x = paddle.to_tensor(np_data)
y = x[:, [False, False, False, False], 4]
np.testing.assert_allclose(y.numpy(), np_res)
class TestGetitemInStatic(unittest.TestCase):
def setUp(self):
paddle.enable_static()
self.exe = paddle.static.Executor()
def test_combined_index_1(self):
# int tensor + slice (without decreasing axes)
np_data = np.random.randn(3, 4, 5, 6)
np_res = np_data[[0, 1], :, [1, 2]]
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.to_tensor(np_data)
y = _getitem_static(x, ([0, 1], slice(None, None, None), [1, 2]))
res = self.exe.run(fetch_list=[y.name])
np.testing.assert_allclose(res[0], np_res)
def test_combined_index_2(self):
# int tensor + slice (with decreasing axes)
np_data = np.random.randn(3, 4, 5, 6)
np_res = np_data[:, 1, [1, 2], 0]
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.to_tensor(np_data)
y = _getitem_static(x, (slice(None, None, None), 1, [1, 2], 0))
res = self.exe.run(fetch_list=[y.name])
np.testing.assert_allclose(res[0], np_res)
def test_combined_index_3(self):
# multiple int tensors, with one int tensor at first axis
np_data = np.random.randn(3, 4, 5, 6, 7)
np_res = np_data[[1, 0], :, [1, 4], 1:5:2, 4]
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.to_tensor(np_data)
y = _getitem_static(
x, ([1, 0], slice(None, None, None), [1, 4], slice(1, 5, 2), 4)
)
res = self.exe.run(fetch_list=[y.name])
np.testing.assert_allclose(res[0], np_res)
def test_combined_index_4(self):
# multiple not adjacent int tensors, with no int tensor at first axis
np_data = np.random.randn(3, 4, 5, 6, 7)
np_res = np_data[:, [1, 0], 0:4:2, [2, 3], 4]
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.to_tensor(np_data)
y = _getitem_static(
x, (slice(None, None, None), [1, 0], slice(0, 4, 2), [2, 3], 4)
)
res = self.exe.run(fetch_list=[y.name])
np.testing.assert_allclose(res[0], np_res)
def test_combined_index_5(self):
# multiple adjacent int tensors, with no int tensor at first axis
np_data = np.random.randn(3, 4, 5, 6, 7)
np_res = np_data[::2, [1, 0], [2, 3], 0:4:2]
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.to_tensor(np_data)
y = _getitem_static(
x, (slice(None, None, 2), [1, 0], [2, 3], slice(0, 4, 2))
)
res = self.exe.run(fetch_list=[y.name])
np.testing.assert_allclose(res[0], np_res)
def test_combined_index_6(self):
# multiple adjacent and not adjacent int tensors, with no int tensor at first axis
np_data = np.random.randn(3, 4, 5, 6, 7)
np_res = np_data[::2, [1, 0], [2, 3], 0:4:2, [4, 6]]
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.to_tensor(np_data)
y = _getitem_static(
x,
(slice(None, None, 2), [1, 0], [2, 3], slice(0, 4, 2), [4, 6]),
)
res = self.exe.run(fetch_list=[y.name])
np.testing.assert_allclose(res[0], np_res)
def test_combined_index_7(self):
# multiple adjacent and not adjacent int tensors (rank > 1d), with no int tensor at first axis
np_data = np.random.randn(3, 4, 5, 6, 7)
np_res = np_data[::2, [[1, 0]], [[2, 3]], 0:4:2, [[4, 6]]]
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.to_tensor(np_data)
y = _getitem_static(
x,
(
slice(None, None, 2),
[[1, 0]],
[[2, 3]],
slice(0, 4, 2),
[[4, 6]],
),
)
res = self.exe.run(fetch_list=[y.name])
np.testing.assert_allclose(res[0], np_res)
def test_combined_index_8(self):
# multiple adjacent and not adjacent int tensors (rank > 1d), with int tensor at first axis
np_data = np.random.randn(3, 4, 5, 6, 7)
np_res = np_data[
[[1, 0], [0, 1]], [[2, 3], [1, 0]], 0:4:2, [[3, 5], [4, 2]]
]
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.to_tensor(np_data)
y = _getitem_static(
x,
(
[[1, 0], [0, 1]],
[[2, 3], [1, 0]],
slice(0, 4, 2),
[[3, 5], [4, 2]],
),
)
res = self.exe.run(fetch_list=[y.name])
np.testing.assert_allclose(res[0], np_res)
def test_combined_index_9(self):
# multiple int tensors, with broadcast.
np_data = np.random.randn(3, 4, 5, 6, 7)
np_res = np_data[[[1, 0]], [1, 0], 0:4:2, [[3, 5], [4, 2]]]
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.to_tensor(np_data)
y = _getitem_static(
x, ([[1, 0]], [1, 0], slice(0, 4, 2), [[3, 5], [4, 2]])
)
res = self.exe.run(fetch_list=[y.name])
np.testing.assert_allclose(res[0], np_res)
def test_combined_index_10(self):
# only one bool tensor with basic-index
np_data = np.random.randn(3, 4, 5, 6)
np_res = np_data[:, [True, False, True, False], 4]
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.to_tensor(np_data)
y = _getitem_static(
x, (slice(None, None, None), [True, False, True, False], 4)
)
res = self.exe.run(fetch_list=[y.name])
np.testing.assert_allclose(res[0], np_res)
def test_combined_index_11(self):
# only one bool tensor with all False
np_data = np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6))
np_res = np_data[:, [False, False, False, False], 4]
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.to_tensor(np_data)
y = _getitem_static(
x, (slice(None, None, None), [False, False, False, False], 4)
)
res = self.exe.run(fetch_list=[y.name])
np.testing.assert_allclose(res[0], np_res)
class TestGetItemErrorCase(unittest.TestCase):
def setUp(self):
paddle.disable_static()
def test_bool_shape_error1(self):
x = paddle.randn((4, 3, 2))
with self.assertRaises(IndexError):
y = _getitem_static(x, ([True, False]))
def test_bool_shape_error2(self):
x = paddle.randn((4, 3, 2))
with self.assertRaises(IndexError):
y = _getitem_static(x, (1, paddle.to_tensor([True, False]), [0, 1]))
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.fluid.variable_index import _setitem_static
class TestSetitemInDygraph(unittest.TestCase):
def setUp(self):
paddle.disable_static()
def test_combined_index_1(self):
np_data = np.zeros((3, 4, 5, 6), dtype='float32')
x = paddle.to_tensor(np_data)
np_data[[0, 1], :, [1, 2]] = 10.0
x[[0, 1], :, [1, 2]] = 10.0
np.testing.assert_allclose(x.numpy(), np_data)
def test_combined_index_2(self):
np_data = np.ones((3, 4, 5, 6), dtype='float32')
x = paddle.to_tensor(np_data)
np_data[:, 1, [1, 2], 0] = 10.0
x[:, 1, [1, 2], 0] = 10.0
np.testing.assert_allclose(x.numpy(), np_data)
def test_combined_index_3(self):
np_data = np.ones((3, 4, 5, 6), dtype='int32')
x = paddle.to_tensor(np_data)
np_data[:, [True, False, True, False], [1, 4]] = 10
x[:, [True, False, True, False], [1, 4]] = 10
np.testing.assert_allclose(x.numpy(), np_data)
class TestSetitemInStatic(unittest.TestCase):
def setUp(self):
paddle.enable_static()
self.exe = paddle.static.Executor()
def test_combined_index_1(self):
# int tensor + slice (without decreasing axes)
np_data = np.zeros((3, 4, 5, 6), dtype='float32')
np_data[[0, 1], :, [1, 2]] = 10.0
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.zeros((3, 4, 5, 6), dtype='float32')
y = _setitem_static(
x, ([0, 1], slice(None, None, None), [1, 2]), 10.0
)
res = self.exe.run(fetch_list=[y.name])
np.testing.assert_allclose(res[0], np_data)
def test_combined_index_2(self):
# int tensor + slice (with decreasing axes)
np_data = np.ones((3, 4, 5, 6), dtype='float32')
np_data[:, 1, [1, 2], 0] = 10.0
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.ones((3, 4, 5, 6), dtype='float32')
y = _setitem_static(
x, (slice(None, None, None), 1, [1, 2], 0), 10.0
)
res = self.exe.run(fetch_list=[y.name])
np.testing.assert_allclose(res[0], np_data)
def test_combined_index_3(self):
# int tensor + bool tensor + slice (without decreasing axes)
np_data = np.ones((3, 4, 5, 6), dtype='int32')
np_data[:, [True, False, True, False], [1, 4]] = 10
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.ones((3, 4, 5, 6), dtype='int32')
y = _setitem_static(
x,
(slice(None, None, None), [True, False, True, False], [1, 4]),
10,
)
res = self.exe.run(fetch_list=[y.name])
np.testing.assert_allclose(res[0], np_data)
def test_combined_index_4(self):
# int tensor (with ranks > 1) + bool tensor + slice (with decreasing axes)
np_data = np.ones((3, 4, 5, 6), dtype='int32')
np_data[[0, 0], [True, False, True, False], [[0, 2], [1, 4]], 4] = 16
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.ones((3, 4, 5, 6), dtype='int32')
y = _setitem_static(
x,
([0, 0], [True, False, True, False], [[0, 2], [1, 4]], 4),
16,
)
res = self.exe.run(fetch_list=[y.name])
np.testing.assert_allclose(res[0], np_data)
def test_combined_index_5(self):
# int tensor + slice + Ellipsis
np_data = np.ones((3, 4, 5, 6), dtype='int32')
np_data[..., [1, 4, 3], ::2] = 5
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.ones((3, 4, 5, 6), dtype='int32')
y = _setitem_static(
x,
(..., [1, 4, 3], slice(None, None, 2)),
5,
)
res = self.exe.run(fetch_list=[y.name])
np.testing.assert_allclose(res[0], np_data)
......@@ -1343,13 +1343,6 @@ class TestError(TestSetValueBase):
x[::one] = self.value
def _bool_list_error(self):
with self.assertRaises(TypeError):
x = paddle.ones(shape=self.shape, dtype=self.dtype)
if paddle.in_dynamic_mode():
x[[True, False, 0]] = 0
else:
x = paddle.static.setitem(x, [True, False, 0], 0)
with self.assertRaises(IndexError):
x = paddle.ones(shape=self.shape, dtype=self.dtype)
if paddle.in_dynamic_mode():
......@@ -1380,7 +1373,6 @@ class TestError(TestSetValueBase):
paddle.enable_static()
with paddle.static.program_guard(self.program):
self._value_type_error()
self._step_error()
self._bool_list_error()
self._bool_tensor_error()
self._broadcast_mismatch()
......
......@@ -944,11 +944,9 @@ class TestVarBase(unittest.TestCase):
var_tensor[var_tensor < 0.55], np_value[np_value < 0.55]
)
with self.assertRaises(ValueError):
var_tensor[[False, False, False, False]]
with self.assertRaises(ValueError):
with self.assertRaises(IndexError):
var_tensor[[True, False]]
with self.assertRaises(ValueError):
with self.assertRaises(IndexError):
var_tensor[[True, False, False, False, False]]
with self.assertRaises(IndexError):
var_tensor[paddle.to_tensor([[True, False, False, False]])]
......
......@@ -257,10 +257,6 @@ class TestVariable(unittest.TestCase):
self.assertTrue((result[2] == expected[2]).all())
self.assertTrue((result[3] == expected[3]).all())
with self.assertRaises(IndexError):
one = paddle.ones(shape=[1])
res = x[one, [0, 0]]
def _test_slice_index_list(self, place):
data = np.random.rand(2, 3).astype("float32")
prog = paddle.static.Program()
......@@ -323,9 +319,6 @@ class TestVariable(unittest.TestCase):
self.assertTrue((result[5] == expected[5]).all())
self.assertTrue((result[6] == expected[6]).all())
with self.assertRaises(IndexError):
res = x[[1.2, 0]]
def _test_slice_index_list_bool(self, place):
data = np.random.rand(2, 3, 4).astype("float32")
np_idx = np.array([[True, False, False], [True, False, True]])
......@@ -375,9 +368,6 @@ class TestVariable(unittest.TestCase):
with self.assertRaises(IndexError):
res = x[[True, False, False]]
with self.assertRaises(ValueError):
with paddle.static.program_guard(prog):
res = x[[False, False]]
def _test_slice_index_scalar_bool(self, place):
data = np.random.rand(1, 3, 4).astype("float32")
......
......@@ -655,9 +655,9 @@ class TestApiWhileLoopSliceInBody(unittest.TestCase):
startup_program = Program()
with program_guard(main_program, startup_program):
x = paddle.static.data(name='x', shape=[-1, 5], dtype='int32')
z = paddle.tensor.fill_constant([1], 'int32', 0)
z = paddle.tensor.fill_constant([], 'int32', 0)
x_shape = paddle.shape(x)
i = paddle.tensor.fill_constant([1], 'int32', 0)
i = paddle.tensor.fill_constant([], 'int32', 0)
z, _ = paddle.static.nn.while_loop(cond, body, [z, i])
place = (
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册