未验证 提交 e7df47ec 编写于 作者: W WeiXin 提交者: GitHub

support tensor index. (#34824)

* polish code

* polish code.

* polish code.

* polish code.

* polish code.
上级 678a259a
...@@ -815,7 +815,7 @@ void BindImperative(py::module *m_ptr) { ...@@ -815,7 +815,7 @@ void BindImperative(py::module *m_ptr) {
.def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value")) .def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value"))
.def("__init__", &InitVarBaseFromTensorWithArgDefault, py::arg("tensor")) .def("__init__", &InitVarBaseFromTensorWithArgDefault, py::arg("tensor"))
.def("__init__", &InitVarBaseFromNumpyWithKwargs) .def("__init__", &InitVarBaseFromNumpyWithKwargs)
.def("__setitem__", .def("__setitem_varbase__",
[](std::shared_ptr<imperative::VarBase> &self, py::handle _index, [](std::shared_ptr<imperative::VarBase> &self, py::handle _index,
py::object &value_obj) { py::object &value_obj) {
VLOG(4) << "Call __setitem__"; VLOG(4) << "Call __setitem__";
......
...@@ -22,7 +22,7 @@ import paddle ...@@ -22,7 +22,7 @@ import paddle
from .. import framework from .. import framework
from .. import core from .. import core
from .. import unique_name from .. import unique_name
from ..framework import Variable, Parameter, ParamBase, _getitem_impl_ from ..framework import Variable, Parameter, ParamBase, _getitem_impl_, _setitem_impl_
from .base import switch_to_static_graph from .base import switch_to_static_graph
from .math_op_patch import monkey_patch_math_varbase from .math_op_patch import monkey_patch_math_varbase
from .parallel import scale_loss from .parallel import scale_loss
...@@ -543,23 +543,41 @@ def monkey_patch_varbase(): ...@@ -543,23 +543,41 @@ def monkey_patch_varbase():
array = array.astype(dtype) array = array.astype(dtype)
return array return array
def contain_tensor(item):
if not isinstance(item, tuple):
item = [item]
for slice_item in item:
if isinstance(slice_item, slice):
if isinstance(slice_item.start, Variable) \
or isinstance(slice_item.stop, Variable) \
or isinstance(slice_item.step, Variable):
return True
else:
if isinstance(slice_item, Variable):
return True
return False
def __getitem__(self, item): def __getitem__(self, item):
def contain_tensor(item): def is_list_tuple(index, contain_type):
if not isinstance(item, tuple): def _is_list_tuple(item):
item = [item] if not (isinstance(item, (list, tuple)) or
type(item) == contain_type):
for slice_item in item: return False
if isinstance(slice_item, slice): if isinstance(item, (tuple, list)):
if isinstance(slice_item.start, Variable) \ for s in item:
or isinstance(slice_item.stop, Variable) \ if not _is_list_tuple(s):
or isinstance(slice_item.step, Variable): return False
return True return True
else:
if isinstance(slice_item, Variable):
return True
return False
if contain_tensor(item): if not isinstance(index, (tuple, list)):
return False
for s in index:
if not _is_list_tuple(s):
return False
return True
if contain_tensor(item) or is_list_tuple(item, int):
# 1. Call _getitem_impl_ when item contains tensor. # 1. Call _getitem_impl_ when item contains tensor.
# Why not call a c++ function ? Because item can't be parsed when it contains tensor. # Why not call a c++ function ? Because item can't be parsed when it contains tensor.
return _getitem_impl_(self, item) return _getitem_impl_(self, item)
...@@ -568,6 +586,17 @@ def monkey_patch_varbase(): ...@@ -568,6 +586,17 @@ def monkey_patch_varbase():
# 2. Call c++ func getitem_index_not_tensor to speedup. # 2. Call c++ func getitem_index_not_tensor to speedup.
return self._getitem_index_not_tensor(item) return self._getitem_index_not_tensor(item)
def __setitem__(self, item, value):
if contain_tensor(item):
# 1. Call _setitem_impl_ when item contains tensor.
# Why not call a c++ function ? Because item can't be parsed when it contains tensor.
return _setitem_impl_(self, item, value)
else:
# 2. Call c++ func __setitem_varbase__ to speedup.
return self.__setitem_varbase__(item, value)
for method_name, method in ( for method_name, method in (
("__bool__", __bool__), ("__nonzero__", __nonzero__), ("__bool__", __bool__), ("__nonzero__", __nonzero__),
("_to_static_var", _to_static_var), ("set_value", set_value), ("_to_static_var", _to_static_var), ("set_value", set_value),
...@@ -577,7 +606,8 @@ def monkey_patch_varbase(): ...@@ -577,7 +606,8 @@ def monkey_patch_varbase():
("__str__", __str__), ("__repr__", __str__), ("__str__", __str__), ("__repr__", __str__),
("__deepcopy__", __deepcopy__), ("__module__", "paddle"), ("__deepcopy__", __deepcopy__), ("__module__", "paddle"),
("__name__", "Tensor"), ("__array__", __array__), ("__name__", "Tensor"), ("__array__", __array__),
("__getitem__", __getitem__), ("item", item)): ("__getitem__", __getitem__), ("item", item),
("__setitem__", __setitem__)):
setattr(core.VarBase, method_name, method) setattr(core.VarBase, method_name, method)
# NOTE(zhiqiu): pybind11 will set a default __str__ method of enum class. # NOTE(zhiqiu): pybind11 will set a default __str__ method of enum class.
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
from __future__ import print_function from __future__ import print_function
import unittest import unittest
from functools import reduce
import paddle import paddle
from paddle.fluid.framework import default_main_program, Program, convert_np_dtype_to_dtype_, in_dygraph_mode from paddle.fluid.framework import default_main_program, Program, convert_np_dtype_to_dtype_, in_dygraph_mode
import paddle import paddle
...@@ -228,21 +230,25 @@ class TestVariable(unittest.TestCase): ...@@ -228,21 +230,25 @@ class TestVariable(unittest.TestCase):
out2 = x[0:, ...] out2 = x[0:, ...]
out3 = x[..., 1:] out3 = x[..., 1:]
out4 = x[...] out4 = x[...]
out5 = x[[1, 0], [0, 0]]
out6 = x[([1, 0], [0, 0])]
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
result = exe.run(prog, fetch_list=[out1, out2, out3, out4]) result = exe.run(prog, fetch_list=[out1, out2, out3, out4, out5, out6])
expected = [data[0:, ..., 1:], data[0:, ...], data[..., 1:], data[...]] expected = [
data[0:, ..., 1:], data[0:, ...], data[..., 1:], data[...],
data[[1, 0], [0, 0]], data[([1, 0], [0, 0])]
]
self.assertTrue((result[0] == expected[0]).all()) self.assertTrue((result[0] == expected[0]).all())
self.assertTrue((result[1] == expected[1]).all()) self.assertTrue((result[1] == expected[1]).all())
self.assertTrue((result[2] == expected[2]).all()) self.assertTrue((result[2] == expected[2]).all())
self.assertTrue((result[3] == expected[3]).all()) self.assertTrue((result[3] == expected[3]).all())
self.assertTrue((result[4] == expected[4]).all())
self.assertTrue((result[5] == expected[5]).all())
with self.assertRaises(IndexError): with self.assertRaises(IndexError):
res = x[[1, 0], [0, 0]]
with self.assertRaises(TypeError):
res = x[[1.2, 0]] res = x[[1.2, 0]]
def _test_slice_index_list_bool(self, place): def _test_slice_index_list_bool(self, place):
...@@ -472,5 +478,455 @@ class TestVariableSlice(unittest.TestCase): ...@@ -472,5 +478,455 @@ class TestVariableSlice(unittest.TestCase):
self._test_item_none_and_decrease(place) self._test_item_none_and_decrease(place)
class TestListIndex(unittest.TestCase):
def numel(self, shape):
return reduce(lambda x, y: x * y, shape)
def test_static_graph_list_index(self):
paddle.enable_static()
inps_shape = [3, 4, 5, 2]
array = np.arange(
self.numel(inps_shape), dtype='float32').reshape(inps_shape)
index_shape = [3, 3, 2, 1]
index = np.arange(self.numel(index_shape)).reshape(index_shape)
for _ in range(3):
program = paddle.static.Program()
index_mod = (index % (array.shape[0])).tolist()
with paddle.static.program_guard(program):
x = paddle.static.data(
name='x', shape=array.shape, dtype='float32')
y = x[index_mod]
place = paddle.fluid.CPUPlace(
) if not paddle.fluid.core.is_compiled_with_cuda(
) else paddle.fluid.CUDAPlace(0)
prog = paddle.static.default_main_program()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
fetch_list = [y.name]
getitem_np = array[index_mod]
getitem_pp = exe.run(prog,
feed={x.name: array},
fetch_list=fetch_list)
self.assertTrue(np.array_equal(getitem_np, getitem_pp[0]))
array = array[0]
index = index[0]
def test_dygraph_list_index(self):
paddle.disable_static()
inps_shape = [3, 4, 5, 3]
array = np.arange(self.numel(inps_shape)).reshape(inps_shape)
index_shape = [2, 3, 4, 5, 6]
index = np.arange(self.numel(index_shape)).reshape(index_shape)
for _ in range(len(inps_shape) - 1):
pt = paddle.to_tensor(array)
index_mod = (index % (array.shape[-1])).tolist()
try:
getitem_np = array[index_mod]
except:
with self.assertRaises(ValueError):
getitem_pp = pt[index_mod]
array = array[0]
index = index[0]
continue
getitem_pp = pt[index_mod]
self.assertTrue(np.array_equal(getitem_np, getitem_pp.numpy()))
array = array[0]
index = index[0]
def test_static_graph_list_index_muti_dim(self):
paddle.enable_static()
inps_shape = [3, 4, 5]
array = np.arange(
self.numel(inps_shape), dtype='float32').reshape(inps_shape)
index_shape = [2, 2]
index1 = np.arange(self.numel(index_shape)).reshape(index_shape)
index2 = np.arange(self.numel(index_shape)).reshape(index_shape) + 2
value_shape = [3, 2, 2, 3]
value_np = np.arange(
self.numel(value_shape), dtype='float32').reshape(value_shape) + 100
index_mod1 = (index1 % (min(array.shape))).tolist()
index_mod2 = (index2 % (min(array.shape))).tolist()
program = paddle.static.Program()
with paddle.static.program_guard(program):
x = paddle.static.data(name='x', shape=array.shape, dtype='float32')
value = paddle.static.data(
name='value', shape=value_np.shape, dtype='float32')
index1 = paddle.static.data(
name='index1', shape=index1.shape, dtype='int32')
index2 = paddle.static.data(
name='index2', shape=index2.shape, dtype='int32')
y = x[index1, index2]
place = paddle.fluid.CPUPlace(
) if not paddle.fluid.core.is_compiled_with_cuda(
) else paddle.fluid.CUDAPlace(0)
prog = paddle.static.default_main_program()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
fetch_list = [y.name]
array2 = array.copy()
y2 = array2[index_mod1, index_mod2]
getitem_pp = exe.run(prog,
feed={
x.name: array,
index1.name: index_mod1,
index2.name: index_mod2
},
fetch_list=fetch_list)
self.assertTrue(
np.array_equal(y2, getitem_pp[0]),
msg='\n numpy:{},\n paddle:{}'.format(y2, getitem_pp[0]))
def test_dygraph_list_index_muti_dim(self):
paddle.disable_static()
inps_shape = [3, 4, 5]
array = np.arange(
self.numel(inps_shape), dtype='float32').reshape(inps_shape)
index_shape = [2, 2]
index1 = np.arange(self.numel(index_shape)).reshape(index_shape)
index2 = np.arange(self.numel(index_shape)).reshape(index_shape) + 2
value_shape = [3, 2, 2, 3]
value_np = np.arange(
self.numel(value_shape), dtype='float32').reshape(value_shape) + 100
index_mod1 = (index1 % (min(array.shape))).tolist()
index_mod2 = (index2 % (min(array.shape))).tolist()
x = paddle.to_tensor(array)
index_t1 = paddle.to_tensor(index_mod1)
index_t2 = paddle.to_tensor(index_mod2)
y_np = array[index_t1, index_t2]
y = x[index_t1, index_t2]
self.assertTrue(np.array_equal(y.numpy(), y_np))
def run_setitem_list_index(self, array, index, value_np):
x = paddle.static.data(name='x', shape=array.shape, dtype='float32')
value = paddle.static.data(
name='value', shape=value_np.shape, dtype='float32')
x[index] = value
y = x
place = paddle.fluid.CPUPlace()
prog = paddle.static.default_main_program()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
fetch_list = [y.name]
array2 = array.copy()
try:
array2[index] = value_np
except:
with self.assertRaises(ValueError):
setitem_pp = exe.run(
prog,
feed={x.name: array,
value.name: value_np},
fetch_list=fetch_list)
return
setitem_pp = exe.run(prog,
feed={x.name: array,
value.name: value_np},
fetch_list=fetch_list)
self.assertTrue(
np.array_equal(array2, setitem_pp[0]),
msg='\n numpy:{},\n paddle:{}'.format(array2, setitem_pp[0]))
def test_static_graph_setitem_list_index(self):
paddle.enable_static()
# case 1:
inps_shape = [3, 4, 5, 2, 3]
array = np.arange(
self.numel(inps_shape), dtype='float32').reshape(inps_shape)
index_shape = [3, 3, 1, 2]
index = np.arange(self.numel(index_shape)).reshape(index_shape)
value_shape = inps_shape[3:]
value_np = np.arange(
self.numel(value_shape), dtype='float32').reshape(value_shape) + 100
for _ in range(3):
program = paddle.static.Program()
index_mod = (index % (min(array.shape))).tolist()
with paddle.static.program_guard(program):
self.run_setitem_list_index(array, index_mod, value_np)
array = array[0]
index = index[0]
# case 2:
inps_shape = [3, 4, 5, 4, 3]
array = np.arange(
self.numel(inps_shape), dtype='float32').reshape(inps_shape)
index_shape = [4, 3, 2, 2]
index = np.arange(self.numel(index_shape)).reshape(index_shape)
value_shape = [3]
value_np = np.arange(
self.numel(value_shape), dtype='float32').reshape(value_shape) + 100
for _ in range(4):
program = paddle.static.Program()
index_mod = (index % (min(array.shape))).tolist()
with paddle.static.program_guard(program):
self.run_setitem_list_index(array, index_mod, value_np)
array = array[0]
index = index[0]
# case 3:
inps_shape = [3, 4, 5, 3, 3]
array = np.arange(
self.numel(inps_shape), dtype='float32').reshape(inps_shape)
index_shape = [4, 3, 2, 2]
index = np.arange(self.numel(index_shape)).reshape(index_shape)
value_shape = [3, 2, 2, 3]
value_np = np.arange(
self.numel(value_shape), dtype='float32').reshape(value_shape) + 100
index_mod = (index % (min(array.shape))).tolist()
self.run_setitem_list_index(array, index_mod, value_np)
def test_static_graph_tensor_index_setitem_muti_dim(self):
paddle.enable_static()
inps_shape = [3, 4, 5, 4]
array = np.arange(
self.numel(inps_shape), dtype='float32').reshape(inps_shape)
index_shape = [2, 3, 4]
index1 = np.arange(
self.numel(index_shape), dtype='int32').reshape(index_shape)
index2 = np.arange(
self.numel(index_shape), dtype='int32').reshape(index_shape) + 2
value_shape = [4]
value_np = np.arange(
self.numel(value_shape), dtype='float32').reshape(value_shape) + 100
for _ in range(3):
index_mod1 = index1 % (min(array.shape))
index_mod2 = index2 % (min(array.shape))
array2 = array.copy()
array2[index_mod1, index_mod2] = value_np
array3 = array.copy()
array3[index_mod1] = value_np
program = paddle.static.Program()
with paddle.static.program_guard(program):
x1 = paddle.static.data(
name='x1', shape=array.shape, dtype='float32')
x2 = paddle.static.data(
name='x2', shape=array.shape, dtype='float32')
value = paddle.static.data(
name='value', shape=value_np.shape, dtype='float32')
index_1 = paddle.static.data(
name='index_1', shape=index1.shape, dtype='int32')
index_2 = paddle.static.data(
name='index_2', shape=index2.shape, dtype='int32')
x1[index_1, index_2] = value
x2[index_1] = value
place = paddle.fluid.CPUPlace(
) if not paddle.fluid.core.is_compiled_with_cuda(
) else paddle.fluid.CUDAPlace(0)
prog = paddle.static.default_main_program()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
fetch_list = [x1.name, x2.name]
setitem_pp = exe.run(prog,
feed={
x1.name: array,
x2.name: array,
value.name: value_np,
index_1.name: index_mod1,
index_2.name: index_mod2
},
fetch_list=fetch_list)
self.assertTrue(
np.array_equal(array2, setitem_pp[0]),
msg='\n numpy:{},\n paddle:{}'.format(array2,
setitem_pp[0]))
self.assertTrue(
np.array_equal(array3, setitem_pp[1]),
msg='\n numpy:{},\n paddle:{}'.format(array3,
setitem_pp[1]))
array = array[0]
index1 = index1[0]
index2 = index2[0]
def test_static_graph_array_index_muti_dim(self):
paddle.enable_static()
inps_shape = [3, 4, 5, 4]
array = np.arange(
self.numel(inps_shape), dtype='float32').reshape(inps_shape)
index_shape = [2, 3, 4]
index1 = np.arange(
self.numel(index_shape), dtype='int32').reshape(index_shape)
index2 = np.arange(
self.numel(index_shape), dtype='int32').reshape(index_shape) + 2
for _ in range(3):
index_mod1 = index1 % (min(array.shape))
index_mod2 = index2 % (min(array.shape))
array2 = array.copy()
array2[index_mod1, index_mod2] = 1
y_np1 = array2[index_mod2, index_mod1]
array3 = array.copy()
array3[index_mod1] = 2.5
y_np2 = array3[index_mod2]
program = paddle.static.Program()
with paddle.static.program_guard(program):
x1 = paddle.static.data(
name='x1', shape=array.shape, dtype='float32')
x2 = paddle.static.data(
name='x2', shape=array.shape, dtype='float32')
x1[index_mod1, index_mod2] = 1
x2[index_mod1] = 2.5
y1 = x1[index_mod2, index_mod1]
y2 = x2[index_mod2]
place = paddle.fluid.CPUPlace(
) if not paddle.fluid.core.is_compiled_with_cuda(
) else paddle.fluid.CUDAPlace(0)
prog = paddle.static.default_main_program()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
fetch_list = [x1.name, x2.name, y1.name, y2.name]
setitem_pp = exe.run(prog,
feed={x1.name: array,
x2.name: array},
fetch_list=fetch_list)
self.assertTrue(
np.array_equal(array2, setitem_pp[0]),
msg='\n numpy:{},\n paddle:{}'.format(array2,
setitem_pp[0]))
self.assertTrue(
np.array_equal(array3, setitem_pp[1]),
msg='\n numpy:{},\n paddle:{}'.format(array3,
setitem_pp[1]))
self.assertTrue(
np.array_equal(y_np1, setitem_pp[2]),
msg='\n numpy:{},\n paddle:{}'.format(y_np1, setitem_pp[2]))
self.assertTrue(
np.array_equal(y_np2, setitem_pp[3]),
msg='\n numpy:{},\n paddle:{}'.format(y_np2, setitem_pp[3]))
array = array[0]
index1 = index1[0]
index2 = index2[0]
def test_dygraph_array_index_muti_dim(self):
paddle.disable_static()
inps_shape = [3, 4, 5, 4]
array = np.arange(
self.numel(inps_shape), dtype='float32').reshape(inps_shape)
index_shape = [2, 3, 4]
index1 = np.arange(
self.numel(index_shape), dtype='int32').reshape(index_shape)
index2 = np.arange(
self.numel(index_shape), dtype='int32').reshape(index_shape) + 2
for _ in range(3):
index_mod1 = index1 % (min(array.shape))
index_mod2 = index2 % (min(array.shape))
index_mod_t1 = paddle.to_tensor(index_mod1)
index_mod_t2 = paddle.to_tensor(index_mod2)
# 2 dim getitem
array1 = array.copy()
y_np1 = array1[index_mod2, index_mod1]
tensor1 = paddle.to_tensor(array)
y_t1 = tensor1[index_mod_t2, index_mod_t1]
self.assertTrue(
np.array_equal(y_t1.numpy(), y_np1),
msg='\n numpy:{},\n paddle:{}'.format(y_np1, y_t1.numpy()))
# 1 dim getitem
array2 = array.copy()
y_np2 = array2[index_mod2]
tensor2 = paddle.to_tensor(array)
y_t2 = tensor2[index_mod_t2]
self.assertTrue(
np.array_equal(y_t2.numpy(), y_np2),
msg='\n numpy:{},\n paddle:{}'.format(y_np2, y_t2.numpy()))
# 2 dim setitem
array1 = array.copy()
array1[index_mod1, index_mod2] = 1
tensor1[index_mod_t1, index_mod_t2] = 1
self.assertTrue(
np.array_equal(tensor1.numpy(), array1),
msg='\n numpy:{},\n paddle:{}'.format(array1, tensor1.numpy()))
# 1 dim setitem
array2 = array.copy()
array2[index_mod1] = 2.5
tensor2[index_mod_t1] = 2.5
self.assertTrue(
np.array_equal(tensor2.numpy(), array2),
msg='\n numpy:{},\n paddle:{}'.format(array2, tensor2.numpy()))
array = array[0]
index1 = index1[0]
index2 = index2[0]
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -16,10 +16,172 @@ import sys ...@@ -16,10 +16,172 @@ import sys
import numpy as np import numpy as np
from . import unique_name from . import unique_name
from . import core from . import core
import paddle
MAX_INTEGER = 2**31 - 1 MAX_INTEGER = 2**31 - 1
def is_list_tuple(index, contain_type):
def _is_list_tuple(item):
if not (isinstance(item, (list, tuple)) or type(item) == contain_type):
return False
if isinstance(item, (tuple, list)):
for s in item:
if not _is_list_tuple(s):
return False
return True
if not isinstance(index, (tuple, list)):
return False
for s in index:
if not _is_list_tuple(s):
return False
return True
def is_one_dim_list(index, contain_type):
if isinstance(index, list):
for i in index:
if not isinstance(i, contain_type):
return False
else:
return False
return True
def get_list_index_shape(var_dims, index_dims):
var_dims_size = len(var_dims)
index_dims_size = len(index_dims)
out_dims_size = var_dims_size - index_dims[0] + index_dims_size - 1
out_dims_shape = [1] * out_dims_size
out_dims_shape[:index_dims_size - 1] = index_dims[1:]
out_dims_shape[index_dims_size - 1:] = var_dims[index_dims[0]:]
return out_dims_shape
class SliceInfo:
def __init__(self):
self.pre_shape = None
self.indexes = []
def update(self, index):
if is_list_tuple(index, int) or isinstance(index, (
paddle.fluid.Variable, np.ndarray)):
# convert index to Tensor
if not isinstance(index, paddle.fluid.Variable):
index = paddle.assign(index)
self.indexes.append(index)
if self.pre_shape is None:
self.pre_shape = index.shape
else:
if self.pre_shape != index.shape:
# broadcast
cur_shape = paddle.broadcast_shape(self.pre_shape,
index.shape)
for i in range(len(self.indexes)):
self.indexes[i] = paddle.broadcast_to(self.indexes[i],
cur_shape)
self.pre_shape = self.indexes[-1].shape
else:
raise ValueError(
"Index should be list/tuple of int or Tensor, but received {}.".
format(index))
def shape_stride(self, shape):
s = [1] * len(shape)
for i in range(len(shape) - 2, -1, -1):
s[i] = shape[i + 1] * s[i + 1]
return s
def numel(self, shape):
return reduce(lambda x, y: x * y, shape)
def get_offset_stride(self, tensor_shape):
for index in self.indexes:
if not isinstance(index, paddle.fluid.Variable):
raise ValueError(
"only support list/tensor index, but received {}.".format(
type(index)))
if len(self.indexes) <= len(tensor_shape) or len(self.indexes) == 1:
shape = paddle.stack(self.indexes)
axes = list(range(1, len(self.pre_shape) + 1)) + [0, ]
else:
raise ValueError(
"too many indices for tensor: tensor is {}-dimensional, but {} were indexed".
format(len(tensor_shape), self.pre_shape[0]))
shape_transpose = paddle.transpose(shape, axes)
return shape_transpose
def get_item(self, tensor):
shape_transpose = self.get_offset_stride(tensor.shape)
index = paddle.assign(shape_transpose)
return paddle.gather_nd(tensor, index)
def set_item(self, tensor_origin, value):
if not isinstance(value, paddle.fluid.Variable):
value = paddle.assign(value)
tensor_type = None
if tensor_origin.dtype in [
core.VarDesc.VarType.FP32, core.VarDesc.VarType.FP64
]:
tensor = tensor_origin
else:
tensor_type = tensor_origin.dtype
tensor = tensor_origin.astype(core.VarDesc.VarType.FP32)
if value.dtype != tensor.dtype:
value = value.astype(tensor.dtype)
shape_transpose = self.get_offset_stride(tensor_origin.shape)
index = paddle.assign(shape_transpose)
gather_tensor_shape = get_list_index_shape(
tensor.shape, [len(self.indexes), ] + list(self.indexes[-1].shape))
value_dims_bd = [1, ] * len(gather_tensor_shape)
value_dims_bd[-len(value.shape):] = list(value.shape)
for i in range(len(gather_tensor_shape)):
if not (value_dims_bd[i] == gather_tensor_shape[i] or
value_dims_bd[i] == 1):
raise ValueError("{} can not broadcast into {}".format(
value.shape, gather_tensor_shape))
value_broadcast = paddle.broadcast_to(value, gather_tensor_shape)
value_1d = value_broadcast.reshape([-1] + gather_tensor_shape[len(
index.shape) - 1:])
index_1d = index.reshape([-1, index.shape[-1]])
tensor_stride = paddle.assign(
self.shape_stride(tensor.shape[:index.shape[-1]]))
inds = []
for i in range(index_1d.shape[0]):
temp = (index_1d[i] * tensor_stride).sum()
inds.append(temp)
index_1d = paddle.stack(inds).reshape([-1])
t_reshape = tensor.reshape([-1] + list(tensor.shape[index.shape[-1]:]))
out = paddle.scatter(t_reshape, index_1d, value_1d)
if tensor_type is not None:
out = out.astype(tensor_type)
tensor_origin[:] = out.reshape(tensor_origin.shape)
return tensor_origin
def replace_ellipsis(var, item): def replace_ellipsis(var, item):
from .framework import Variable from .framework import Variable
# Use slice(None) to replace Ellipsis. # Use slice(None) to replace Ellipsis.
...@@ -32,7 +194,9 @@ def replace_ellipsis(var, item): ...@@ -32,7 +194,9 @@ def replace_ellipsis(var, item):
item = list(item) item = list(item)
# Remove Variable to skip bug when counting Ellipsis # Remove Variable to skip bug when counting Ellipsis
item_remove_var = [ele for ele in item if not isinstance(ele, Variable)] item_remove_var = [
ele for ele in item if not isinstance(ele, (Variable, np.ndarray))
]
ell_count = item_remove_var.count(Ellipsis) ell_count = item_remove_var.count(Ellipsis)
if ell_count == 0: if ell_count == 0:
return item return item
...@@ -99,6 +263,9 @@ def _getitem_impl_(var, item): ...@@ -99,6 +263,9 @@ def _getitem_impl_(var, item):
Sliced variable Sliced variable
""" """
from .framework import default_main_program, Variable from .framework import default_main_program, Variable
if isinstance(item, list):
if not is_one_dim_list(item, int):
item = tuple(item)
if not isinstance(item, tuple): if not isinstance(item, tuple):
item = (item, ) item = (item, )
...@@ -113,6 +280,7 @@ def _getitem_impl_(var, item): ...@@ -113,6 +280,7 @@ def _getitem_impl_(var, item):
use_strided_slice = False use_strided_slice = False
item, none_axes = replace_none(item) item, none_axes = replace_none(item)
item = replace_ellipsis(var, item) item = replace_ellipsis(var, item)
slice_info = SliceInfo()
for dim, slice_item in enumerate(item): for dim, slice_item in enumerate(item):
if is_integer_or_scalar_tensor(slice_item): if is_integer_or_scalar_tensor(slice_item):
...@@ -151,6 +319,11 @@ def _getitem_impl_(var, item): ...@@ -151,6 +319,11 @@ def _getitem_impl_(var, item):
elif isinstance(slice_item, list): elif isinstance(slice_item, list):
all_bool = True all_bool = True
if is_list_tuple(slice_item, int):
slice_info.update(slice_item)
continue
for i in slice_item: for i in slice_item:
if type(i) is int: if type(i) is int:
all_bool = False all_bool = False
...@@ -188,35 +361,43 @@ def _getitem_impl_(var, item): ...@@ -188,35 +361,43 @@ def _getitem_impl_(var, item):
idx = assign(np.array(slice_item).astype("int32")) idx = assign(np.array(slice_item).astype("int32"))
return index_select(var, index=idx, axis=0) return index_select(var, index=idx, axis=0)
elif isinstance(slice_item, Variable): elif isinstance(slice_item, np.ndarray):
if len(item) != 1: slice_info.update(slice_item)
raise IndexError( continue
"When index contains a Tensor, its length must be 1, but received {}.". elif isinstance(slice_item, (Variable)):
format(len(item))) if len(item) == 1:
from ..tensor import index_select, gather_nd from ..tensor import index_select, gather_nd
from .layers.nn import where from .layers.nn import where
if slice_item.dtype == core.VarDesc.VarType.BOOL: if slice_item.dtype == paddle.bool:
if len(slice_item.shape) > len(var.shape): if len(slice_item.shape) > len(var.shape):
raise IndexError(
"The dims of bool index doesn't match indexed array, "
"the dims of bool index except to be equal or less "
"than {}, but received {}.".format(
len(var.shape), len(slice_item.shape)))
for i, dim_len in enumerate(slice_item.shape):
if dim_len != var.shape[i]:
raise IndexError( raise IndexError(
"The dimension of bool index doesn't match indexed array along "\ "The dims of bool index doesn't match indexed array, "
"dimension {}, the target dimension is {}, but received {}.". "the dims of bool index except to be equal or less "
format(i, var.shape[i], dim_len)) "than {}, but received {}.".format(
bool_2_idx = where(slice_item == True) len(var.shape), len(slice_item.shape)))
return gather_nd(var, bool_2_idx) for i, dim_len in enumerate(slice_item.shape):
return index_select(var, index=slice_item, axis=0) if dim_len != var.shape[i]:
raise IndexError(
"The dimension of bool index doesn't match indexed array along "\
"dimension {}, the target dimension is {}, but received {}.".
format(i, var.shape[i], dim_len))
bool_2_idx = where(slice_item == True)
return gather_nd(var, bool_2_idx)
else:
if len(slice_item.shape) == 1:
return index_select(var, index=slice_item, axis=0)
else:
slice_info.update(slice_item)
continue
else:
slice_info.update(slice_item)
continue
else: else:
raise IndexError( raise IndexError(
"Valid index accept int or slice or ellipsis, but received {}.". "Valid index accept int or slice or ellipsis or list, but received {}.".
format(slice_item)) format(slice_item))
axes.append(dim) axes.append(dim)
...@@ -225,6 +406,13 @@ def _getitem_impl_(var, item): ...@@ -225,6 +406,13 @@ def _getitem_impl_(var, item):
steps.append(step) steps.append(step)
use_strided_slice = True if step != 1 else use_strided_slice use_strided_slice = True if step != 1 else use_strided_slice
if slice_info.indexes:
if len(slice_info.indexes) != len(item):
raise IndexError(
"Valid index accept int or slice or ellipsis or list, but received {}.".
format(item))
return slice_info.get_item(var)
inputs = {'Input': [var]} inputs = {'Input': [var]}
attrs = { attrs = {
'axes': axes, 'axes': axes,
...@@ -298,7 +486,9 @@ def _setitem_impl_(var, item, value): ...@@ -298,7 +486,9 @@ def _setitem_impl_(var, item, value):
from .framework import default_main_program, Variable from .framework import default_main_program, Variable
inputs = {'Input': var} inputs = {'Input': var}
if isinstance(item, list):
if not is_one_dim_list(item, int):
item = tuple(item)
# 1. Parse item # 1. Parse item
if not isinstance(item, tuple): if not isinstance(item, tuple):
item = (item, ) item = (item, )
...@@ -311,7 +501,7 @@ def _setitem_impl_(var, item, value): ...@@ -311,7 +501,7 @@ def _setitem_impl_(var, item, value):
item, none_axes = replace_none(item) item, none_axes = replace_none(item)
item = replace_ellipsis(var, item) item = replace_ellipsis(var, item)
slice_info = SliceInfo()
dim = 0 dim = 0
for _, slice_item in enumerate(item): for _, slice_item in enumerate(item):
if is_integer_or_scalar_tensor(slice_item): if is_integer_or_scalar_tensor(slice_item):
...@@ -319,6 +509,16 @@ def _setitem_impl_(var, item, value): ...@@ -319,6 +509,16 @@ def _setitem_impl_(var, item, value):
start = slice_item start = slice_item
end = slice_item + 1 if slice_item != -1 else MAX_INTEGER end = slice_item + 1 if slice_item != -1 else MAX_INTEGER
step = 1 step = 1
elif isinstance(slice_item, list):
if not is_list_tuple(slice_item, int):
raise TypeError(
"Only support int or list in index list. But revceived {}.".
format(slice_item))
slice_info.update(slice_item)
continue
elif isinstance(slice_item, (Variable, np.ndarray)):
slice_info.update(slice_item)
continue
elif isinstance(slice_item, slice): elif isinstance(slice_item, slice):
start = slice_item.start start = slice_item.start
...@@ -358,7 +558,12 @@ def _setitem_impl_(var, item, value): ...@@ -358,7 +558,12 @@ def _setitem_impl_(var, item, value):
steps.append(step) steps.append(step)
dim += 1 dim += 1
if slice_info.indexes:
if len(slice_info.indexes) != len(item):
raise IndexError(
"Valid index accept int or slice or ellipsis or list, but received {}.".
format(item))
return slice_info.set_item(var, value)
attrs = { attrs = {
'axes': axes, 'axes': axes,
'starts': starts, 'starts': starts,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册