未验证 提交 82339ed1 编写于 作者: Z zyfncg 提交者: GitHub

Support getitem by ellipsis index in dynamic mode (#34267)

* Support getitem by ellipsis index in dynamic mode

* change some code style
上级 0438b604
......@@ -432,19 +432,24 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
const auto &shape = tensor->dims();
const int rank = shape.size();
const int size = PyTuple_GET_SIZE(index);
// specified_dims is the number of dimensions which indexed by Interger,
// Slices.
int specified_dims = 0;
for (int dim = 0; dim < size; ++dim) {
PyObject *slice_item = PyTuple_GetItem(index, dim);
if (PyCheckInteger(slice_item) || PySlice_Check(slice_item)) {
specified_dims++;
}
}
PADDLE_ENFORCE_EQ(
size <= rank, true,
platform::errors::InvalidArgument(
"too many indices (%d) for tensor of dimension %d", size, rank));
for (int dim = 0; dim < size; ++dim) {
PyObject *slice_item = PyTuple_GetItem(index, dim);
PADDLE_ENFORCE_EQ(PyCheckInteger(slice_item) || PySlice_Check(slice_item),
true,
platform::errors::InvalidArgument(
"Currently, VarBase.__getitem__() only allows "
"indexing by Integers, Slices, and tuples of "
"these types, but received %s in %dth slice item",
std::string(Py_TYPE(slice_item)->tp_name), dim + 1));
for (int i = 0, dim = 0; i < size; ++i) {
PyObject *slice_item = PyTuple_GetItem(index, i);
infer_flags->push_back(1);
int dim_len = shape[dim];
if (PyCheckInteger(slice_item)) {
......@@ -467,7 +472,8 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
slice_ends->push_back(start + 1);
slice_strides->push_back(1);
decrease_axis->push_back(dim);
} else {
dim++;
} else if (PySlice_Check(slice_item)) {
// slice item
Py_ssize_t start, end, step;
PySliceObject *p = reinterpret_cast<PySliceObject *>(slice_item);
......@@ -475,12 +481,22 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
// :: or : or 0:dim_len:1
if (start == 0 && end == dim_len && step == 1) {
dim++;
continue;
}
slice_axes->push_back(dim);
slice_starts->push_back(start);
slice_ends->push_back(end);
slice_strides->push_back(step);
dim++;
} else if (slice_item == Py_Ellipsis) {
dim += rank - specified_dims;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Currently, VarBase.__getitem__() only allows "
"indexing by Integers, Slices, Ellipsis, and tuples of "
"these types, but received %s in %dth slice item",
std::string(Py_TYPE(slice_item)->tp_name), i + 1));
}
}
if (!PyTuple_Check(_index)) Py_DecRef(index);
......
......@@ -652,6 +652,43 @@ class TestVarBase(unittest.TestCase):
np.array_equal(local_out[15], tensor_array[::-1, ::-1, ::-1]))
self.assertTrue(np.array_equal(local_out[16], tensor_array[-4:4]))
def _test_for_getitem_ellipsis_index(self):
shape = (64, 3, 5, 256)
np_fp32_value = np.random.random(shape).astype('float32')
np_int_value = np.random.randint(1, 100, shape)
var_fp32 = paddle.to_tensor(np_fp32_value)
var_int = paddle.to_tensor(np_int_value)
def assert_getitem_ellipsis_index(var_tensor, var_np):
var = [
var_tensor[..., 0].numpy(),
var_tensor[..., 1, 0].numpy(),
var_tensor[0, ..., 1, 0].numpy(),
var_tensor[1, ..., 1].numpy(),
var_tensor[2, ...].numpy(),
var_tensor[2, 0, ...].numpy(),
var_tensor[2, 0, 1, ...].numpy(),
var_tensor[...].numpy(),
var_tensor[:, ..., 100].numpy(),
]
self.assertTrue(np.array_equal(var[0], var_np[..., 0]))
self.assertTrue(np.array_equal(var[1], var_np[..., 1, 0]))
self.assertTrue(np.array_equal(var[2], var_np[0, ..., 1, 0]))
self.assertTrue(np.array_equal(var[3], var_np[1, ..., 1]))
self.assertTrue(np.array_equal(var[4], var_np[2, ...]))
self.assertTrue(np.array_equal(var[5], var_np[2, 0, ...]))
self.assertTrue(np.array_equal(var[6], var_np[2, 0, 1, ...]))
self.assertTrue(np.array_equal(var[7], var_np[...]))
self.assertTrue(np.array_equal(var[8], var_np[:, ..., 100]))
var_fp32 = paddle.to_tensor(np_fp32_value)
var_int = paddle.to_tensor(np_int_value)
assert_getitem_ellipsis_index(var_fp32, np_fp32_value)
assert_getitem_ellipsis_index(var_int, np_int_value)
def _test_for_var(self):
np_value = np.random.random((30, 100, 100)).astype('float32')
w = fluid.dygraph.to_variable(np_value)
......@@ -664,6 +701,7 @@ class TestVarBase(unittest.TestCase):
self._test_slice()
self._test_slice_for_tensor_attr()
self._test_for_var()
self._test_for_getitem_ellipsis_index()
var = fluid.dygraph.to_variable(self.array)
self.assertTrue(np.array_equal(var[1, :].numpy(), self.array[1, :]))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册