未验证 提交 f2a56c6a 编写于 作者: Z zyfncg 提交者: GitHub

fix bug of indexing with ellipsis (#37182)

上级 10cc040d
...@@ -549,13 +549,20 @@ static void ParseIndexingSlice( ...@@ -549,13 +549,20 @@ static void ParseIndexingSlice(
// specified_dims is the number of dimensions which indexed by Interger, // specified_dims is the number of dimensions which indexed by Interger,
// Slices. // Slices.
int specified_dims = 0; int specified_dims = 0;
int ell_count = 0;
for (int dim = 0; dim < size; ++dim) { for (int dim = 0; dim < size; ++dim) {
PyObject *slice_item = PyTuple_GetItem(index, dim); PyObject *slice_item = PyTuple_GetItem(index, dim);
if (PyCheckInteger(slice_item) || PySlice_Check(slice_item)) { if (PyCheckInteger(slice_item) || PySlice_Check(slice_item)) {
specified_dims++; specified_dims++;
} else if (slice_item == Py_Ellipsis) {
ell_count++;
} }
} }
PADDLE_ENFORCE_LE(ell_count, 1,
platform::errors::InvalidArgument(
"An index can only have a single ellipsis ('...')"));
for (int i = 0, dim = 0; i < size; ++i) { for (int i = 0, dim = 0; i < size; ++i) {
PyObject *slice_item = PyTuple_GetItem(index, i); PyObject *slice_item = PyTuple_GetItem(index, i);
...@@ -660,7 +667,7 @@ static void ParseIndexingSlice( ...@@ -660,7 +667,7 @@ static void ParseIndexingSlice(
} }
// valid_index is the number of dimensions exclude None index // valid_index is the number of dimensions exclude None index
const int valid_indexs = size - none_axes->size(); const int valid_indexs = size - none_axes->size() - ell_count;
PADDLE_ENFORCE_EQ(valid_indexs <= rank, true, PADDLE_ENFORCE_EQ(valid_indexs <= rank, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Too many indices (%d) for tensor of dimension %d.", "Too many indices (%d) for tensor of dimension %d.",
......
...@@ -702,6 +702,11 @@ class TestVarBase(unittest.TestCase): ...@@ -702,6 +702,11 @@ class TestVarBase(unittest.TestCase):
assert_getitem_ellipsis_index(var_fp32, np_fp32_value) assert_getitem_ellipsis_index(var_fp32, np_fp32_value)
assert_getitem_ellipsis_index(var_int, np_int_value) assert_getitem_ellipsis_index(var_int, np_int_value)
# test 1 dim tensor
var_one_dim = paddle.to_tensor([1, 2, 3, 4])
self.assertTrue(
np.array_equal(var_one_dim[..., 0].numpy(), np.array([1])))
def _test_none_index(self): def _test_none_index(self):
shape = (8, 64, 5, 256) shape = (8, 64, 5, 256)
np_value = np.random.random(shape).astype('float32') np_value = np.random.random(shape).astype('float32')
......
...@@ -226,19 +226,22 @@ class TestVariable(unittest.TestCase): ...@@ -226,19 +226,22 @@ class TestVariable(unittest.TestCase):
prog = paddle.static.Program() prog = paddle.static.Program()
with paddle.static.program_guard(prog): with paddle.static.program_guard(prog):
x = paddle.assign(data) x = paddle.assign(data)
y = paddle.assign([1, 2, 3, 4])
out1 = x[0:, ..., 1:] out1 = x[0:, ..., 1:]
out2 = x[0:, ...] out2 = x[0:, ...]
out3 = x[..., 1:] out3 = x[..., 1:]
out4 = x[...] out4 = x[...]
out5 = x[[1, 0], [0, 0]] out5 = x[[1, 0], [0, 0]]
out6 = x[([1, 0], [0, 0])] out6 = x[([1, 0], [0, 0])]
out7 = y[..., 0]
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
result = exe.run(prog, fetch_list=[out1, out2, out3, out4, out5, out6]) result = exe.run(prog,
fetch_list=[out1, out2, out3, out4, out5, out6, out7])
expected = [ expected = [
data[0:, ..., 1:], data[0:, ...], data[..., 1:], data[...], data[0:, ..., 1:], data[0:, ...], data[..., 1:], data[...],
data[[1, 0], [0, 0]], data[([1, 0], [0, 0])] data[[1, 0], [0, 0]], data[([1, 0], [0, 0])], np.array([1])
] ]
self.assertTrue((result[0] == expected[0]).all()) self.assertTrue((result[0] == expected[0]).all())
...@@ -247,6 +250,7 @@ class TestVariable(unittest.TestCase): ...@@ -247,6 +250,7 @@ class TestVariable(unittest.TestCase):
self.assertTrue((result[3] == expected[3]).all()) self.assertTrue((result[3] == expected[3]).all())
self.assertTrue((result[4] == expected[4]).all()) self.assertTrue((result[4] == expected[4]).all())
self.assertTrue((result[5] == expected[5]).all()) self.assertTrue((result[5] == expected[5]).all())
self.assertTrue((result[6] == expected[6]).all())
with self.assertRaises(IndexError): with self.assertRaises(IndexError):
res = x[[1.2, 0]] res = x[[1.2, 0]]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册