未验证 提交 60c5adaa 编写于 作者: W WeiXin 提交者: GitHub

support numpy dtype and polish code of list index. (#35404)

* support numpy dtype and polish code of list index.

* polish code.
上级 5675042d
......@@ -331,7 +331,14 @@ GetVarBaseListFromPyHandle(const py::handle &handle) {
return result;
}
static bool IsNumpyType(PyObject *obj) {
// It is not a good way to judge the type of obj by its type'name. Maybe using
// `PyArray_IsScalar` will be better. However, this interface cannot be used
// by including pybind11, and it needs to compile with numpy.
auto type_name = std::string(Py_TYPE(obj)->tp_name);
return type_name == "numpy.int64" || type_name == "numpy.longlong" ||
type_name == "numpy.int32" || type_name == "numpy.int16";
}
static imperative::NameVarBaseMap ConvertToNameVarBaseMap(
const PyNameVarBaseMap &map) {
imperative::NameVarBaseMap result;
......@@ -372,7 +379,7 @@ static int _PySlice_GetIndices(PySliceObject *r, Py_ssize_t length,
if (r->step == Py_None) {
*step = 1;
} else {
if (PyCheckInteger(r->step)) {
if (PyCheckInteger(r->step) || IsNumpyType(r->step)) {
*step = PyLong_AsLong(r->step);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
......@@ -384,7 +391,7 @@ static int _PySlice_GetIndices(PySliceObject *r, Py_ssize_t length,
if (r->start == Py_None) {
*start = *step < 0 ? length - 1 : 0;
} else {
if (PyCheckInteger(r->start)) {
if (PyCheckInteger(r->start) || IsNumpyType(r->start)) {
*start = PyLong_AsLong(r->start);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
......@@ -398,7 +405,7 @@ static int _PySlice_GetIndices(PySliceObject *r, Py_ssize_t length,
if (r->stop == Py_None) {
*stop = *step < 0 ? -1 : length;
} else {
if (PyCheckInteger(r->stop)) {
if (PyCheckInteger(r->stop) || IsNumpyType(r->stop)) {
*stop = PyLong_AsLong(r->stop);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
......@@ -456,7 +463,7 @@ static void ParseIndexingSlice(
infer_flags->push_back(1);
int dim_len = shape[dim];
if (PyCheckInteger(slice_item)) {
if (PyCheckInteger(slice_item) || IsNumpyType(slice_item)) {
// integer, PyLong_AsLong supports both int and long
int start = static_cast<int>(PyLong_AsLong(slice_item));
auto s_t = start;
......
......@@ -544,7 +544,7 @@ def monkey_patch_varbase():
return array
def contain_tensor(item):
if not isinstance(item, tuple):
if not isinstance(item, (tuple, list)):
item = [item]
for slice_item in item:
......@@ -554,20 +554,21 @@ def monkey_patch_varbase():
or isinstance(slice_item.step, Variable):
return True
else:
if isinstance(slice_item, Variable):
if isinstance(slice_item,
Variable) and Variable.dtype != paddle.bool:
return True
return False
def __getitem__(self, item):
def is_list_tuple(index, contain_type):
def _is_list_tuple(item):
if not (isinstance(item, (list, tuple)) or
type(item) == contain_type):
return False
if isinstance(item, (tuple, list)):
for s in item:
if not _is_list_tuple(s):
return False
else:
if type(item) != contain_type:
return False
return True
if not isinstance(index, (tuple, list)):
......@@ -599,7 +600,28 @@ def monkey_patch_varbase():
return False
if contain_tensor_or_list(item):
def is_combine_index(item):
var_type = None
item_type = None
if isinstance(item, (tuple, list)):
for slice_item in item:
if item_type is None:
item_type = type(slice_item)
else:
if type(slice_item) != item_type:
return True
if isinstance(slice_item, Variable):
if var_type is None:
var_type = slice_item.dtype
else:
if var_type != slice_item.dtype:
return True
return False
return False
if contain_tensor_or_list(item) and not is_combine_index(item):
# To reuse code with static graph,
# Call _setitem_impl_ when item contains tensor or list.
return _setitem_impl_(self, item, value)
......
......@@ -779,6 +779,40 @@ class TestVarBase(unittest.TestCase):
for i, e in enumerate(w):
self.assertTrue(np.array_equal(e.numpy(), np_value[i]))
def _test_numpy_index(self):
array = np.arange(120).reshape([4, 5, 6])
t = paddle.to_tensor(array)
self.assertTrue(np.array_equal(t[np.longlong(0)].numpy(), array[0]))
self.assertTrue(
np.array_equal(t[np.longlong(0):np.longlong(4):np.longlong(2)]
.numpy(), array[0:4:2]))
self.assertTrue(np.array_equal(t[np.int64(0)].numpy(), array[0]))
self.assertTrue(
np.array_equal(t[np.int32(1):np.int32(4):np.int32(2)].numpy(),
array[1:4:2]))
self.assertTrue(
np.array_equal(t[np.int16(0):np.int16(4):np.int16(2)].numpy(),
array[0:4:2]))
def _test_list_index(self):
# case1:
array = np.arange(120).reshape([6, 5, 4])
x = paddle.to_tensor(array)
py_idx = [[0, 2, 0, 1, 3], [0, 0, 1, 2, 0]]
idx = [paddle.to_tensor(py_idx[0]), paddle.to_tensor(py_idx[1])]
self.assertTrue(np.array_equal(x[idx].numpy(), array[py_idx]))
self.assertTrue(np.array_equal(x[py_idx].numpy(), array[py_idx]))
# case2:
tensor_x = paddle.to_tensor(
np.zeros(12).reshape(2, 6).astype(np.float32))
tensor_y1 = paddle.zeros([1]) + 2
tensor_y2 = paddle.zeros([1]) + 5
tensor_x[:, tensor_y1:tensor_y2] = 42
res = tensor_x.numpy()
exp = np.array([[0., 0., 42., 42., 42., 0.],
[0., 0., 42., 42., 42., 0.]])
self.assertTrue(np.array_equal(res, exp))
def test_slice(self):
with fluid.dygraph.guard():
self._test_slice()
......@@ -787,6 +821,8 @@ class TestVarBase(unittest.TestCase):
self._test_for_getitem_ellipsis_index()
self._test_none_index()
self._test_bool_index()
self._test_numpy_index()
self._test_list_index()
var = fluid.dygraph.to_variable(self.array)
self.assertTrue(np.array_equal(var[1, :].numpy(), self.array[1, :]))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册