diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc
index e6b8239b44208837b815755f08efb248a8a57ac7..0c8d65340f6b99c29b115814692c2e976a433554 100644
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -815,7 +815,7 @@ void BindImperative(py::module *m_ptr) {
       .def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value"))
       .def("__init__", &InitVarBaseFromTensorWithArgDefault, py::arg("tensor"))
       .def("__init__", &InitVarBaseFromNumpyWithKwargs)
-      .def("__setitem__",
+      .def("__setitem_varbase__",
           [](std::shared_ptr<imperative::VarBase> &self, py::handle _index,
              py::object &value_obj) {
             VLOG(4) << "Call __setitem__";
diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py
index 2fda67e891abfd1ecc1a7f7a0a62cc860349c6fd..4a36445ac0c9479ac7aa48e97f2a6158190e9b74 100644
--- a/python/paddle/fluid/dygraph/varbase_patch_methods.py
+++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py
@@ -22,7 +22,7 @@ import paddle
 from .. import framework
 from .. import core
 from .. import unique_name
-from ..framework import Variable, Parameter, ParamBase, _getitem_impl_
+from ..framework import Variable, Parameter, ParamBase, _getitem_impl_, _setitem_impl_
 from .base import switch_to_static_graph
 from .math_op_patch import monkey_patch_math_varbase
 from .parallel import scale_loss
@@ -543,23 +543,41 @@ def monkey_patch_varbase():
             array = array.astype(dtype)
         return array
 
+    def contain_tensor(item):
+        if not isinstance(item, tuple):
+            item = [item]
+
+        for slice_item in item:
+            if isinstance(slice_item, slice):
+                if isinstance(slice_item.start, Variable) \
+                        or isinstance(slice_item.stop, Variable) \
+                        or isinstance(slice_item.step, Variable):
+                    return True
+            else:
+                if isinstance(slice_item, Variable):
+                    return True
+        return False
+
     def __getitem__(self, item):
-        def contain_tensor(item):
-            if not isinstance(item, tuple):
-                item = [item]
-
-            for slice_item in item:
-                if isinstance(slice_item, slice):
-                    if isinstance(slice_item.start, Variable) \
-                            or isinstance(slice_item.stop, Variable) \
-                            or isinstance(slice_item.step, Variable):
-                        return True
-                else:
-                    if isinstance(slice_item, Variable):
-                        return True
-            return False
+        def is_list_tuple(index, contain_type):
+            def _is_list_tuple(item):
+                if not (isinstance(item, (list, tuple)) or
+                        type(item) == contain_type):
+                    return False
+                if isinstance(item, (tuple, list)):
+                    for s in item:
+                        if not _is_list_tuple(s):
+                            return False
+                return True
 
-        if contain_tensor(item):
+            if not isinstance(index, (tuple, list)):
+                return False
+            for s in index:
+                if not _is_list_tuple(s):
+                    return False
+            return True
+
+        if contain_tensor(item) or is_list_tuple(item, int):
             # 1. Call _getitem_impl_ when item contains tensor.
             # Why not call a c++ function ? Because item can't be parsed when it contains tensor.
             return _getitem_impl_(self, item)
@@ -568,6 +586,17 @@ def monkey_patch_varbase():
             # 2. Call c++ func getitem_index_not_tensor to speedup.
             return self._getitem_index_not_tensor(item)
 
+    def __setitem__(self, item, value):
+
+        if contain_tensor(item):
+            # 1. Call _setitem_impl_ when item contains tensor.
+            # Why not call a C++ function? Because item can't be parsed when it contains tensor.
+            return _setitem_impl_(self, item, value)
+
+        else:
+            # 2. Call c++ func __setitem_varbase__ to speed up.
+            return self.__setitem_varbase__(item, value)
+
     for method_name, method in (
         ("__bool__", __bool__), ("__nonzero__", __nonzero__),
         ("_to_static_var", _to_static_var), ("set_value", set_value),
@@ -577,7 +606,8 @@ def monkey_patch_varbase():
         ("__str__", __str__), ("__repr__", __str__),
         ("__deepcopy__", __deepcopy__), ("__module__", "paddle"),
         ("__name__", "Tensor"), ("__array__", __array__),
-        ("__getitem__", __getitem__), ("item", item)):
+        ("__getitem__", __getitem__), ("item", item),
+        ("__setitem__", __setitem__)):
         setattr(core.VarBase, method_name, method)
 
     # NOTE(zhiqiu): pybind11 will set a default __str__ method of enum class.
diff --git a/python/paddle/fluid/tests/unittests/test_variable.py b/python/paddle/fluid/tests/unittests/test_variable.py
index 0c120100faf0661c7c2fcaf4b6d30183f9a81190..e9e959266dba8e8f9751e3285eb6c272284e4968 100644
--- a/python/paddle/fluid/tests/unittests/test_variable.py
+++ b/python/paddle/fluid/tests/unittests/test_variable.py
@@ -15,6 +15,8 @@
 from __future__ import print_function
 
 import unittest
+from functools import reduce
+
 import paddle
 from paddle.fluid.framework import default_main_program, Program, convert_np_dtype_to_dtype_, in_dygraph_mode
 import paddle
@@ -228,21 +230,25 @@ class TestVariable(unittest.TestCase):
             out2 = x[0:, ...]
             out3 = x[..., 1:]
             out4 = x[...]
+            out5 = x[[1, 0], [0, 0]]
+            out6 = x[([1, 0], [0, 0])]
 
             exe = paddle.static.Executor(place)
-            result = exe.run(prog, fetch_list=[out1, out2, out3, out4])
+            result = exe.run(prog, fetch_list=[out1, out2, out3, out4, out5, out6])
 
-            expected = [data[0:, ..., 1:], data[0:, ...], data[..., 1:], data[...]]
+            expected = [
+                data[0:, ..., 1:], data[0:, ...], data[..., 1:], data[...],
+                data[[1, 0], [0, 0]], data[([1, 0], [0, 0])]
+            ]
 
             self.assertTrue((result[0] == expected[0]).all())
             self.assertTrue((result[1] == expected[1]).all())
             self.assertTrue((result[2] == expected[2]).all())
             self.assertTrue((result[3] == expected[3]).all())
+            self.assertTrue((result[4] == expected[4]).all())
+            self.assertTrue((result[5] == expected[5]).all())
 
             with self.assertRaises(IndexError):
-                res = x[[1, 0], [0, 0]]
-
-            with self.assertRaises(TypeError):
                 res = x[[1.2, 0]]
 
     def _test_slice_index_list_bool(self, place):
@@ -472,5 +478,455 @@ class TestVariableSlice(unittest.TestCase):
             self._test_item_none_and_decrease(place)
 
 
+class TestListIndex(unittest.TestCase):
+    def numel(self, shape):
+        return reduce(lambda x, y: x * y, shape)
+
+    def test_static_graph_list_index(self):
+        paddle.enable_static()
+
+        inps_shape = [3, 4, 5, 2]
+        array = np.arange(
+            self.numel(inps_shape), dtype='float32').reshape(inps_shape)
+
+        index_shape = [3, 3, 2, 1]
+        index = np.arange(self.numel(index_shape)).reshape(index_shape)
+
+        for _ in range(3):
+            program = paddle.static.Program()
+
+            index_mod = (index % (array.shape[0])).tolist()
+
+            with paddle.static.program_guard(program):
+                x = paddle.static.data(
+                    name='x', shape=array.shape, dtype='float32')
+
+                y = x[index_mod]
+
+                place = paddle.fluid.CPUPlace(
+                ) if not paddle.fluid.core.is_compiled_with_cuda(
+                ) else paddle.fluid.CUDAPlace(0)
+
+                prog = paddle.static.default_main_program()
+                exe = paddle.static.Executor(place)
+
+                exe.run(paddle.static.default_startup_program())
+                fetch_list = [y.name]
+
+                getitem_np = array[index_mod]
+                getitem_pp = exe.run(prog,
+                                     feed={x.name: array},
+                                     fetch_list=fetch_list)
+                self.assertTrue(np.array_equal(getitem_np, getitem_pp[0]))
+
+            array = array[0]
+            index = index[0]
+
+    def test_dygraph_list_index(self):
+        paddle.disable_static()
+
+        inps_shape = [3, 4, 5, 3]
+        array = np.arange(self.numel(inps_shape)).reshape(inps_shape)
+
+        index_shape = [2, 3, 4, 5, 6]
+        index = np.arange(self.numel(index_shape)).reshape(index_shape)
+        for _ in range(len(inps_shape) - 1):
+
+            pt = paddle.to_tensor(array)
+            index_mod = (index % (array.shape[-1])).tolist()
+            try:
+                getitem_np = array[index_mod]
+
+            except:
+                with self.assertRaises(ValueError):
+                    getitem_pp = pt[index_mod]
+                array = array[0]
+                index = index[0]
+                continue
+            getitem_pp = pt[index_mod]
+            self.assertTrue(np.array_equal(getitem_np, getitem_pp.numpy()))
+
+            array = array[0]
+            index = index[0]
+
+    def test_static_graph_list_index_multi_dim(self):
+        paddle.enable_static()
+        inps_shape = [3, 4, 5]
+        array = np.arange(
+            self.numel(inps_shape), dtype='float32').reshape(inps_shape)
+
+        index_shape = [2, 2]
+        index1 = np.arange(self.numel(index_shape)).reshape(index_shape)
+        index2 = np.arange(self.numel(index_shape)).reshape(index_shape) + 2
+
+        value_shape = [3, 2, 2, 3]
+        value_np = np.arange(
+            self.numel(value_shape), dtype='float32').reshape(value_shape) + 100
+
+        index_mod1 = (index1 % (min(array.shape))).tolist()
+        index_mod2 = (index2 % (min(array.shape))).tolist()
+
+        program = paddle.static.Program()
+        with paddle.static.program_guard(program):
+
+            x = paddle.static.data(name='x', shape=array.shape, dtype='float32')
+
+            value = paddle.static.data(
+                name='value', shape=value_np.shape, dtype='float32')
+            index1 = paddle.static.data(
+                name='index1', shape=index1.shape, dtype='int32')
+            index2 = paddle.static.data(
+                name='index2', shape=index2.shape, dtype='int32')
+
+            y = x[index1, index2]
+
+            place = paddle.fluid.CPUPlace(
+            ) if not paddle.fluid.core.is_compiled_with_cuda(
+            ) else paddle.fluid.CUDAPlace(0)
+
+            prog = paddle.static.default_main_program()
+            exe = paddle.static.Executor(place)
+
+            exe.run(paddle.static.default_startup_program())
+            fetch_list = [y.name]
+            array2 = array.copy()
+
+            y2 = array2[index_mod1, index_mod2]
+
+            getitem_pp = exe.run(prog,
+                                 feed={
+                                     x.name: array,
+                                     index1.name: index_mod1,
+                                     index2.name: index_mod2
+                                 },
+                                 fetch_list=fetch_list)
+
+            self.assertTrue(
+                np.array_equal(y2, getitem_pp[0]),
+                msg='\n numpy:{},\n paddle:{}'.format(y2, getitem_pp[0]))
+
+    def test_dygraph_list_index_multi_dim(self):
+        paddle.disable_static()
+        inps_shape = [3, 4, 5]
+        array = np.arange(
+            self.numel(inps_shape), dtype='float32').reshape(inps_shape)
+
+        index_shape = [2, 2]
+        index1 = np.arange(self.numel(index_shape)).reshape(index_shape)
+        index2 = np.arange(self.numel(index_shape)).reshape(index_shape) + 2
+
+        value_shape = [3, 2, 2, 3]
+        value_np = np.arange(
+            self.numel(value_shape), dtype='float32').reshape(value_shape) + 100
+
+        index_mod1 = (index1 % (min(array.shape))).tolist()
+        index_mod2 = (index2 % (min(array.shape))).tolist()
+
+        x = paddle.to_tensor(array)
+        index_t1 = paddle.to_tensor(index_mod1)
+        index_t2 = paddle.to_tensor(index_mod2)
+
+        y_np = array[index_t1, index_t2]
+        y = x[index_t1, index_t2]
+        self.assertTrue(np.array_equal(y.numpy(), y_np))
+
+    def run_setitem_list_index(self, array, index, value_np):
+        x = paddle.static.data(name='x', shape=array.shape, dtype='float32')
+
+        value = paddle.static.data(
+            name='value', shape=value_np.shape, dtype='float32')
+
+        x[index] = value
+        y = x
+        place = paddle.fluid.CPUPlace()
+
+        prog = paddle.static.default_main_program()
+        exe = paddle.static.Executor(place)
+
+        exe.run(paddle.static.default_startup_program())
+        fetch_list = [y.name]
+        array2 = array.copy()
+
+        try:
+            array2[index] = value_np
+        except:
+            with self.assertRaises(ValueError):
+                setitem_pp = exe.run(
+                    prog,
+                    feed={x.name: array,
+                          value.name: value_np},
+                    fetch_list=fetch_list)
+            return
+        setitem_pp = exe.run(prog,
+                             feed={x.name: array,
+                                   value.name: value_np},
+                             fetch_list=fetch_list)
+
+        self.assertTrue(
+            np.array_equal(array2, setitem_pp[0]),
+            msg='\n numpy:{},\n paddle:{}'.format(array2, setitem_pp[0]))
+
+    def test_static_graph_setitem_list_index(self):
+        paddle.enable_static()
+        # case 1:
+        inps_shape = [3, 4, 5, 2, 3]
+        array = np.arange(
+            self.numel(inps_shape), dtype='float32').reshape(inps_shape)
+
+        index_shape = [3, 3, 1, 2]
+        index = np.arange(self.numel(index_shape)).reshape(index_shape)
+
+        value_shape = inps_shape[3:]
+        value_np = np.arange(
+            self.numel(value_shape), dtype='float32').reshape(value_shape) + 100
+
+        for _ in range(3):
+            program = paddle.static.Program()
+
+            index_mod = (index % (min(array.shape))).tolist()
+
+            with paddle.static.program_guard(program):
+                self.run_setitem_list_index(array, index_mod, value_np)
+
+            array = array[0]
+            index = index[0]
+
+        # case 2:
+        inps_shape = [3, 4, 5, 4, 3]
+        array = np.arange(
+            self.numel(inps_shape), dtype='float32').reshape(inps_shape)
+
+        index_shape = [4, 3, 2, 2]
+        index = np.arange(self.numel(index_shape)).reshape(index_shape)
+
+        value_shape = [3]
+        value_np = np.arange(
+            self.numel(value_shape), dtype='float32').reshape(value_shape) + 100
+
+        for _ in range(4):
+            program = paddle.static.Program()
+            index_mod = (index % (min(array.shape))).tolist()
+
+            with paddle.static.program_guard(program):
+                self.run_setitem_list_index(array, index_mod, value_np)
+
+            array = array[0]
+            index = index[0]
+
+        # case 3:
+        inps_shape = [3, 4, 5, 3, 3]
+        array = np.arange(
+            self.numel(inps_shape), dtype='float32').reshape(inps_shape)
+
+        index_shape = [4, 3, 2, 2]
+        index = np.arange(self.numel(index_shape)).reshape(index_shape)
+
+        value_shape = [3, 2, 2, 3]
+        value_np = np.arange(
+            self.numel(value_shape), dtype='float32').reshape(value_shape) + 100
+        index_mod = (index % (min(array.shape))).tolist()
+        self.run_setitem_list_index(array, index_mod, value_np)
+
+    def test_static_graph_tensor_index_setitem_multi_dim(self):
+        paddle.enable_static()
+        inps_shape = [3, 4, 5, 4]
+        array = np.arange(
+            self.numel(inps_shape), dtype='float32').reshape(inps_shape)
+
+        index_shape = [2, 3, 4]
+        index1 = np.arange(
+            self.numel(index_shape), dtype='int32').reshape(index_shape)
+        index2 = np.arange(
+            self.numel(index_shape), dtype='int32').reshape(index_shape) + 2
+
+        value_shape = [4]
+        value_np = np.arange(
+            self.numel(value_shape), dtype='float32').reshape(value_shape) + 100
+        for _ in range(3):
+
+            index_mod1 = index1 % (min(array.shape))
+            index_mod2 = index2 % (min(array.shape))
+
+            array2 = array.copy()
+            array2[index_mod1, index_mod2] = value_np
+            array3 = array.copy()
+            array3[index_mod1] = value_np
+
+            program = paddle.static.Program()
+            with paddle.static.program_guard(program):
+
+                x1 = paddle.static.data(
+                    name='x1', shape=array.shape, dtype='float32')
+                x2 = paddle.static.data(
+                    name='x2', shape=array.shape, dtype='float32')
+
+                value = paddle.static.data(
+                    name='value', shape=value_np.shape, dtype='float32')
+                index_1 = paddle.static.data(
+                    name='index_1', shape=index1.shape, dtype='int32')
+                index_2 = paddle.static.data(
+                    name='index_2', shape=index2.shape, dtype='int32')
+
+                x1[index_1, index_2] = value
+                x2[index_1] = value
+
+                place = paddle.fluid.CPUPlace(
+                ) if not paddle.fluid.core.is_compiled_with_cuda(
+                ) else paddle.fluid.CUDAPlace(0)
+
+                prog = paddle.static.default_main_program()
+                exe = paddle.static.Executor(place)
+
+                exe.run(paddle.static.default_startup_program())
+                fetch_list = [x1.name, x2.name]
+
+                setitem_pp = exe.run(prog,
+                                     feed={
+                                         x1.name: array,
+                                         x2.name: array,
+                                         value.name: value_np,
+                                         index_1.name: index_mod1,
+                                         index_2.name: index_mod2
+                                     },
+                                     fetch_list=fetch_list)
+                self.assertTrue(
+                    np.array_equal(array2, setitem_pp[0]),
+                    msg='\n numpy:{},\n paddle:{}'.format(array2,
+                                                          setitem_pp[0]))
+                self.assertTrue(
+                    np.array_equal(array3, setitem_pp[1]),
+                    msg='\n numpy:{},\n paddle:{}'.format(array3,
+                                                          setitem_pp[1]))
+            array = array[0]
+            index1 = index1[0]
+            index2 = index2[0]
+
+    def test_static_graph_array_index_multi_dim(self):
+        paddle.enable_static()
+        inps_shape = [3, 4, 5, 4]
+        array = np.arange(
+            self.numel(inps_shape), dtype='float32').reshape(inps_shape)
+
+        index_shape = [2, 3, 4]
+        index1 = np.arange(
+            self.numel(index_shape), dtype='int32').reshape(index_shape)
+        index2 = np.arange(
+            self.numel(index_shape), dtype='int32').reshape(index_shape) + 2
+
+        for _ in range(3):
+            index_mod1 = index1 % (min(array.shape))
+            index_mod2 = index2 % (min(array.shape))
+
+            array2 = array.copy()
+            array2[index_mod1, index_mod2] = 1
+            y_np1 = array2[index_mod2, index_mod1]
+            array3 = array.copy()
+            array3[index_mod1] = 2.5
+            y_np2 = array3[index_mod2]
+
+            program = paddle.static.Program()
+            with paddle.static.program_guard(program):
+
+                x1 = paddle.static.data(
+                    name='x1', shape=array.shape, dtype='float32')
+                x2 = paddle.static.data(
+                    name='x2', shape=array.shape, dtype='float32')
+
+                x1[index_mod1, index_mod2] = 1
+                x2[index_mod1] = 2.5
+                y1 = x1[index_mod2, index_mod1]
+                y2 = x2[index_mod2]
+                place = paddle.fluid.CPUPlace(
+                ) if not paddle.fluid.core.is_compiled_with_cuda(
+                ) else paddle.fluid.CUDAPlace(0)
+
+                prog = paddle.static.default_main_program()
+                exe = paddle.static.Executor(place)
+                exe.run(paddle.static.default_startup_program())
+                fetch_list = [x1.name, x2.name, y1.name, y2.name]
+
+                setitem_pp = exe.run(prog,
+                                     feed={x1.name: array,
+                                           x2.name: array},
+                                     fetch_list=fetch_list)
+                self.assertTrue(
+                    np.array_equal(array2, setitem_pp[0]),
+                    msg='\n numpy:{},\n paddle:{}'.format(array2,
+                                                          setitem_pp[0]))
+                self.assertTrue(
+                    np.array_equal(array3, setitem_pp[1]),
+                    msg='\n numpy:{},\n paddle:{}'.format(array3,
+                                                          setitem_pp[1]))
+
+                self.assertTrue(
+                    np.array_equal(y_np1, setitem_pp[2]),
+                    msg='\n numpy:{},\n paddle:{}'.format(y_np1, setitem_pp[2]))
+                self.assertTrue(
+                    np.array_equal(y_np2, setitem_pp[3]),
+                    msg='\n numpy:{},\n paddle:{}'.format(y_np2, setitem_pp[3]))
+            array = array[0]
+            index1 = index1[0]
+            index2 = index2[0]
+
+    def test_dygraph_array_index_multi_dim(self):
+        paddle.disable_static()
+        inps_shape = [3, 4, 5, 4]
+        array = np.arange(
+            self.numel(inps_shape), dtype='float32').reshape(inps_shape)
+        index_shape = [2, 3, 4]
+        index1 = np.arange(
+            self.numel(index_shape), dtype='int32').reshape(index_shape)
+        index2 = np.arange(
+            self.numel(index_shape), dtype='int32').reshape(index_shape) + 2
+
+        for _ in range(3):
+
+            index_mod1 = index1 % (min(array.shape))
+            index_mod2 = index2 % (min(array.shape))
+            index_mod_t1 = paddle.to_tensor(index_mod1)
+            index_mod_t2 = paddle.to_tensor(index_mod2)
+            # 2 dim getitem
+            array1 = array.copy()
+            y_np1 = array1[index_mod2, index_mod1]
+            tensor1 = paddle.to_tensor(array)
+
+            y_t1 = tensor1[index_mod_t2, index_mod_t1]
+
+            self.assertTrue(
+                np.array_equal(y_t1.numpy(), y_np1),
+                msg='\n numpy:{},\n paddle:{}'.format(y_np1, y_t1.numpy()))
+            # 1 dim getitem
+            array2 = array.copy()
+            y_np2 = array2[index_mod2]
+            tensor2 = paddle.to_tensor(array)
+
+            y_t2 = tensor2[index_mod_t2]
+            self.assertTrue(
+                np.array_equal(y_t2.numpy(), y_np2),
+                msg='\n numpy:{},\n paddle:{}'.format(y_np2, y_t2.numpy()))
+
+            # 2 dim setitem
+            array1 = array.copy()
+            array1[index_mod1, index_mod2] = 1
+            tensor1[index_mod_t1, index_mod_t2] = 1
+            self.assertTrue(
+                np.array_equal(tensor1.numpy(), array1),
+                msg='\n numpy:{},\n paddle:{}'.format(array1, tensor1.numpy()))
+            # 1 dim setitem
+            array2 = array.copy()
+
+            array2[index_mod1] = 2.5
+
+            tensor2[index_mod_t1] = 2.5
+
+            self.assertTrue(
+                np.array_equal(tensor2.numpy(), array2),
+                msg='\n numpy:{},\n paddle:{}'.format(array2, tensor2.numpy()))
+
+            array = array[0]
+            index1 = index1[0]
+            index2 = index2[0]
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py
index 5bdf0451b54e91d96070020768a055547eef3fd5..3ae7d8cfd413ed5081709d0d6301ed4f97058398 100644
--- a/python/paddle/fluid/variable_index.py
+++ b/python/paddle/fluid/variable_index.py
@@ -16,10 +16,173 @@
 import sys
 import numpy as np
 from . import unique_name
 from . import core
+import paddle
+from functools import reduce
 
 MAX_INTEGER = 2**31 - 1
 
 
+def is_list_tuple(index, contain_type):
+    def _is_list_tuple(item):
+        if not (isinstance(item, (list, tuple)) or type(item) == contain_type):
+            return False
+        if isinstance(item, (tuple, list)):
+            for s in item:
+                if not _is_list_tuple(s):
+                    return False
+        return True
+
+    if not isinstance(index, (tuple, list)):
+        return False
+    for s in index:
+        if not _is_list_tuple(s):
+            return False
+    return True
+
+
+def is_one_dim_list(index, contain_type):
+    if isinstance(index, list):
+        for i in index:
+            if not isinstance(i, contain_type):
+                return False
+    else:
+        return False
+    return True
+
+
+def get_list_index_shape(var_dims, index_dims):
+    var_dims_size = len(var_dims)
+    index_dims_size = len(index_dims)
+
+    out_dims_size = var_dims_size - index_dims[0] + index_dims_size - 1
+
+    out_dims_shape = [1] * out_dims_size
+
+    out_dims_shape[:index_dims_size - 1] = index_dims[1:]
+
+    out_dims_shape[index_dims_size - 1:] = var_dims[index_dims[0]:]
+    return out_dims_shape
+
+
+class SliceInfo:
+    def __init__(self):
+        self.pre_shape = None
+        self.indexes = []
+
+    def update(self, index):
+        if is_list_tuple(index, int) or isinstance(index, (
+                paddle.fluid.Variable, np.ndarray)):
+            # convert index to Tensor
+            if not isinstance(index, paddle.fluid.Variable):
+                index = paddle.assign(index)
+
+            self.indexes.append(index)
+
+            if self.pre_shape is None:
+                self.pre_shape = index.shape
+            else:
+                if self.pre_shape != index.shape:
+                    # broadcast
+                    cur_shape = paddle.broadcast_shape(self.pre_shape,
+                                                       index.shape)
+                    for i in range(len(self.indexes)):
+                        self.indexes[i] = paddle.broadcast_to(self.indexes[i],
+                                                              cur_shape)
+                self.pre_shape = self.indexes[-1].shape
+        else:
+            raise ValueError(
+                "Index should be list/tuple of int or Tensor, but received {}.".
+                format(index))
+
+    def shape_stride(self, shape):
+        s = [1] * len(shape)
+        for i in range(len(shape) - 2, -1, -1):
+            s[i] = shape[i + 1] * s[i + 1]
+
+        return s
+
+    def numel(self, shape):
+        return reduce(lambda x, y: x * y, shape)
+
+    def get_offset_stride(self, tensor_shape):
+        for index in self.indexes:
+            if not isinstance(index, paddle.fluid.Variable):
+                raise ValueError(
+                    "only support list/tensor index, but received {}.".format(
+                        type(index)))
+
+        if len(self.indexes) <= len(tensor_shape) or len(self.indexes) == 1:
+            shape = paddle.stack(self.indexes)
+            axes = list(range(1, len(self.pre_shape) + 1)) + [0, ]
+
+        else:
+            raise ValueError(
+                "too many indices for tensor: tensor is {}-dimensional, but {} were indexed".
+                format(len(tensor_shape), self.pre_shape[0]))
+
+        shape_transpose = paddle.transpose(shape, axes)
+        return shape_transpose
+
+    def get_item(self, tensor):
+        shape_transpose = self.get_offset_stride(tensor.shape)
+        index = paddle.assign(shape_transpose)
+        return paddle.gather_nd(tensor, index)
+
+    def set_item(self, tensor_origin, value):
+
+        if not isinstance(value, paddle.fluid.Variable):
+            value = paddle.assign(value)
+        tensor_type = None
+
+        if tensor_origin.dtype in [
+                core.VarDesc.VarType.FP32, core.VarDesc.VarType.FP64
+        ]:
+            tensor = tensor_origin
+        else:
+            tensor_type = tensor_origin.dtype
+            tensor = tensor_origin.astype(core.VarDesc.VarType.FP32)
+
+        if value.dtype != tensor.dtype:
+            value = value.astype(tensor.dtype)
+
+        shape_transpose = self.get_offset_stride(tensor_origin.shape)
+        index = paddle.assign(shape_transpose)
+
+        gather_tensor_shape = get_list_index_shape(
+            tensor.shape, [len(self.indexes), ] + list(self.indexes[-1].shape))
+
+        value_dims_bd = [1, ] * len(gather_tensor_shape)
+        value_dims_bd[-len(value.shape):] = list(value.shape)
+
+        for i in range(len(gather_tensor_shape)):
+            if not (value_dims_bd[i] == gather_tensor_shape[i] or
+                    value_dims_bd[i] == 1):
+                raise ValueError("{} can not broadcast into {}".format(
+                    value.shape, gather_tensor_shape))
+
+        value_broadcast = paddle.broadcast_to(value, gather_tensor_shape)
+
+        value_1d = value_broadcast.reshape([-1] + gather_tensor_shape[len(
+            index.shape) - 1:])
+
+        index_1d = index.reshape([-1, index.shape[-1]])
+
+        tensor_stride = paddle.assign(
+            self.shape_stride(tensor.shape[:index.shape[-1]]))
+        inds = []
+        for i in range(index_1d.shape[0]):
+            temp = (index_1d[i] * tensor_stride).sum()
+            inds.append(temp)
+        index_1d = paddle.stack(inds).reshape([-1])
+        t_reshape = tensor.reshape([-1] + list(tensor.shape[index.shape[-1]:]))
+        out = paddle.scatter(t_reshape, index_1d, value_1d)
+        if tensor_type is not None:
+            out = out.astype(tensor_type)
+        tensor_origin[:] = out.reshape(tensor_origin.shape)
+
+        return tensor_origin
+
+
 def replace_ellipsis(var, item):
     from .framework import Variable
     # Use slice(None) to replace Ellipsis.
@@ -32,7 +195,9 @@ def replace_ellipsis(var, item):
     item = list(item)
 
     # Remove Variable to skip bug when counting Ellipsis
-    item_remove_var = [ele for ele in item if not isinstance(ele, Variable)]
+    item_remove_var = [
+        ele for ele in item if not isinstance(ele, (Variable, np.ndarray))
+    ]
     ell_count = item_remove_var.count(Ellipsis)
     if ell_count == 0:
         return item
@@ -99,6 +264,9 @@ def _getitem_impl_(var, item):
         Sliced variable
     """
     from .framework import default_main_program, Variable
+    if isinstance(item, list):
+        if not is_one_dim_list(item, int):
+            item = tuple(item)
 
     if not isinstance(item, tuple):
         item = (item, )
@@ -113,6 +281,7 @@ def _getitem_impl_(var, item):
     use_strided_slice = False
     item, none_axes = replace_none(item)
     item = replace_ellipsis(var, item)
+    slice_info = SliceInfo()
 
     for dim, slice_item in enumerate(item):
         if is_integer_or_scalar_tensor(slice_item):
@@ -151,6 +320,11 @@ def _getitem_impl_(var, item):
 
         elif isinstance(slice_item, list):
             all_bool = True
+
+            if is_list_tuple(slice_item, int):
+                slice_info.update(slice_item)
+                continue
+
             for i in slice_item:
                 if type(i) is int:
                     all_bool = False
@@ -188,35 +362,43 @@ def _getitem_impl_(var, item):
                 idx = assign(np.array(slice_item).astype("int32"))
                 return index_select(var, index=idx, axis=0)
 
-        elif isinstance(slice_item, Variable):
-            if len(item) != 1:
-                raise IndexError(
-                    "When index contains a Tensor, its length must be 1, but received {}.".
-                    format(len(item)))
+        elif isinstance(slice_item, np.ndarray):
+            slice_info.update(slice_item)
+            continue
+        elif isinstance(slice_item, (Variable)):
+            if len(item) == 1:
 
-            from ..tensor import index_select, gather_nd
-            from .layers.nn import where
+                from ..tensor import index_select, gather_nd
+                from .layers.nn import where
 
-            if slice_item.dtype == core.VarDesc.VarType.BOOL:
-                if len(slice_item.shape) > len(var.shape):
-                    raise IndexError(
-                        "The dims of bool index doesn't match indexed array, "
-                        "the dims of bool index except to be equal or less "
-                        "than {}, but received {}.".format(
-                            len(var.shape), len(slice_item.shape)))
-                for i, dim_len in enumerate(slice_item.shape):
-                    if dim_len != var.shape[i]:
+                if slice_item.dtype == paddle.bool:
+                    if len(slice_item.shape) > len(var.shape):
                         raise IndexError(
-                            "The dimension of bool index doesn't match indexed array along "\
-                            "dimension {}, the target dimension is {}, but received {}.".
-                            format(i, var.shape[i], dim_len))
-                bool_2_idx = where(slice_item == True)
-                return gather_nd(var, bool_2_idx)
-            return index_select(var, index=slice_item, axis=0)
+                            "The dims of bool index doesn't match indexed array, "
+                            "the dims of bool index expected to be equal or less "
+                            "than {}, but received {}.".format(
+                                len(var.shape), len(slice_item.shape)))
+                    for i, dim_len in enumerate(slice_item.shape):
+                        if dim_len != var.shape[i]:
+                            raise IndexError(
+                                "The dimension of bool index doesn't match indexed array along "\
+                                "dimension {}, the target dimension is {}, but received {}.".
+                                format(i, var.shape[i], dim_len))
+                    bool_2_idx = where(slice_item == True)
+                    return gather_nd(var, bool_2_idx)
+                else:
+                    if len(slice_item.shape) == 1:
+                        return index_select(var, index=slice_item, axis=0)
+                    else:
+                        slice_info.update(slice_item)
+                        continue
+            else:
+                slice_info.update(slice_item)
+                continue
 
         else:
             raise IndexError(
-                "Valid index accept int or slice or ellipsis, but received {}.".
+                "Valid index accepts int or slice or ellipsis or list, but received {}.".
                 format(slice_item))
 
         axes.append(dim)
@@ -225,6 +407,13 @@ def _getitem_impl_(var, item):
         steps.append(step)
         use_strided_slice = True if step != 1 else use_strided_slice
 
+    if slice_info.indexes:
+        if len(slice_info.indexes) != len(item):
+            raise IndexError(
+                "Valid index accepts int or slice or ellipsis or list, but received {}.".
+                format(item))
+        return slice_info.get_item(var)
+
     inputs = {'Input': [var]}
     attrs = {
         'axes': axes,
@@ -298,7 +487,9 @@ def _setitem_impl_(var, item, value):
     from .framework import default_main_program, Variable
 
     inputs = {'Input': var}
-
+    if isinstance(item, list):
+        if not is_one_dim_list(item, int):
+            item = tuple(item)
     # 1. Parse item
     if not isinstance(item, tuple):
         item = (item, )
@@ -311,7 +502,7 @@ def _setitem_impl_(var, item, value):
 
     item, none_axes = replace_none(item)
     item = replace_ellipsis(var, item)
-
+    slice_info = SliceInfo()
     dim = 0
     for _, slice_item in enumerate(item):
         if is_integer_or_scalar_tensor(slice_item):
@@ -319,6 +510,16 @@ def _setitem_impl_(var, item, value):
             start = slice_item
             end = slice_item + 1 if slice_item != -1 else MAX_INTEGER
             step = 1
+        elif isinstance(slice_item, list):
+            if not is_list_tuple(slice_item, int):
+                raise TypeError(
+                    "Only support int or list in index list. But received {}.".
+                    format(slice_item))
+            slice_info.update(slice_item)
+            continue
+        elif isinstance(slice_item, (Variable, np.ndarray)):
+            slice_info.update(slice_item)
+            continue
 
         elif isinstance(slice_item, slice):
             start = slice_item.start
@@ -358,7 +559,12 @@ def _setitem_impl_(var, item, value):
         steps.append(step)
 
         dim += 1
-
+    if slice_info.indexes:
+        if len(slice_info.indexes) != len(item):
+            raise IndexError(
+                "Valid index accepts int or slice or ellipsis or list, but received {}.".
+                format(item))
+        return slice_info.set_item(var, value)
     attrs = {
         'axes': axes,
         'starts': starts,
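
A minimal usage sketch of the list/tensor indexing behavior this patch enables, distilled from the new TestListIndex cases above (the shapes and variable names here are illustrative, not part of the patch):

    import numpy as np
    import paddle

    paddle.disable_static()

    array = np.arange(3 * 4 * 5, dtype='float32').reshape([3, 4, 5])
    x = paddle.to_tensor(array)

    # getitem: a nested list index now follows numpy-style advanced indexing,
    # routed through _getitem_impl_ / SliceInfo instead of the C++ fast path.
    assert np.array_equal(x[[1, 0], [0, 2]].numpy(), array[[1, 0], [0, 2]])

    # setitem: a Tensor index is routed through _setitem_impl_ / SliceInfo.set_item.
    index = paddle.to_tensor([[0, 1], [1, 2]])
    x[index] = 1.0
    array[[[0, 1], [1, 2]]] = 1.0
    assert np.array_equal(x.numpy(), array)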