未验证 提交 79b9f47e 编写于 作者: Z zyfncg 提交者: GitHub

fix bug of indexing with ellipsis (#37192)

修复了一维Tensor在使用省略号(...)索引时维度检测异常的问题。
上级 dc873eba
......@@ -528,13 +528,20 @@ static void ParseIndexingSlice(
// specified_dims is the number of dimensions which indexed by Interger,
// Slices.
int specified_dims = 0;
int ell_count = 0;
for (int dim = 0; dim < size; ++dim) {
PyObject *slice_item = PyTuple_GetItem(index, dim);
if (PyCheckInteger(slice_item) || PySlice_Check(slice_item)) {
specified_dims++;
} else if (slice_item == Py_Ellipsis) {
ell_count++;
}
}
PADDLE_ENFORCE_LE(ell_count, 1,
platform::errors::InvalidArgument(
"An index can only have a single ellipsis ('...')"));
for (int i = 0, dim = 0; i < size; ++i) {
PyObject *slice_item = PyTuple_GetItem(index, i);
......@@ -639,7 +646,7 @@ static void ParseIndexingSlice(
}
// valid_index is the number of dimensions exclude None index
const int valid_indexs = size - none_axes->size();
const int valid_indexs = size - none_axes->size() - ell_count;
PADDLE_ENFORCE_EQ(valid_indexs <= rank, true,
platform::errors::InvalidArgument(
"Too many indices (%d) for tensor of dimension %d.",
......
......@@ -702,6 +702,11 @@ class TestVarBase(unittest.TestCase):
assert_getitem_ellipsis_index(var_fp32, np_fp32_value)
assert_getitem_ellipsis_index(var_int, np_int_value)
# test 1 dim tensor
var_one_dim = paddle.to_tensor([1, 2, 3, 4])
self.assertTrue(
np.array_equal(var_one_dim[..., 0].numpy(), np.array([1])))
def _test_none_index(self):
shape = (8, 64, 5, 256)
np_value = np.random.random(shape).astype('float32')
......
......@@ -226,19 +226,22 @@ class TestVariable(unittest.TestCase):
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
x = paddle.assign(data)
y = paddle.assign([1, 2, 3, 4])
out1 = x[0:, ..., 1:]
out2 = x[0:, ...]
out3 = x[..., 1:]
out4 = x[...]
out5 = x[[1, 0], [0, 0]]
out6 = x[([1, 0], [0, 0])]
out7 = y[..., 0]
exe = paddle.static.Executor(place)
result = exe.run(prog, fetch_list=[out1, out2, out3, out4, out5, out6])
result = exe.run(prog,
fetch_list=[out1, out2, out3, out4, out5, out6, out7])
expected = [
data[0:, ..., 1:], data[0:, ...], data[..., 1:], data[...],
data[[1, 0], [0, 0]], data[([1, 0], [0, 0])]
data[[1, 0], [0, 0]], data[([1, 0], [0, 0])], np.array([1])
]
self.assertTrue((result[0] == expected[0]).all())
......@@ -247,6 +250,7 @@ class TestVariable(unittest.TestCase):
self.assertTrue((result[3] == expected[3]).all())
self.assertTrue((result[4] == expected[4]).all())
self.assertTrue((result[5] == expected[5]).all())
self.assertTrue((result[6] == expected[6]).all())
with self.assertRaises(IndexError):
res = x[[1.2, 0]]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册