From 697c712f3df0d811eba27f8c49ed13899038ccac Mon Sep 17 00:00:00 2001 From: JYChen Date: Fri, 4 Aug 2023 14:40:50 +0800 Subject: [PATCH] Support combined indexing for __getitem__ and __setitem__ (#55211) * WIP: start writing combined indexing get * list/tuple/Variable * getitem 80% * add setitem * add some unittests for setitem * lazy import * fix some setitem errors * fix advanced indexing with decreasing axes; fix strided_slice input name * combined int-tensor getitem is ok (without boolean support & broadcast); add getitem unittest for static * add broadcast & parse bool tensor for __getitem__ * [change getitem] _getitem_impl_ to _getitem_static, not deleting the former one * refine new getitem; fix ut in variable/var_base * add __getitem__ ut in dygraph * re-dispatch getitem for Py/CPP; fix strided_slice decrease-axes error in dygraph * fix ut; support tensor in slice * [change setitem] _setitem_impl_ to _setitem_static, not deleting the former one * remove some UT (for some, temporarily) * add IndexError to solve timeout problem in static mode * 1. temporarily forbid all-False bool index_put; 2. setitem_static will return a new variable * xpu uses old strategy * rename dy2st setitem ut to avoid same-name problem * dy2st for new combined index * ut case for combined index with dy2st * open ut with all-False-bool setitem * remove useless doc and _getitem_impl_ * change static res * fix static xpu --- paddle/fluid/pybind/eager_method.cc | 3 + .../fluid/dygraph/tensor_patch_methods.py | 59 +- python/paddle/fluid/framework.py | 9 +- python/paddle/fluid/variable_index.py | 811 ++++++++++++------ python/paddle/static/input.py | 9 +- test/CMakeLists.txt | 1 + .../{test_setitem.py => test_jit_setitem.py} | 19 + test/indexing/CMakeLists.txt | 9 + test/indexing/test_getitem.py | 328 +++++++ test/indexing/test_setitem.py | 139 +++ test/legacy_test/test_set_value_op.py | 8 - test/legacy_test/test_var_base.py | 6 +- test/legacy_test/test_variable.py | 10 - test/legacy_test/test_while_loop_op.py | 4 +- 14 files changed, 1082 insertions(+), 333 deletions(-) rename test/dygraph_to_static/{test_setitem.py => test_jit_setitem.py} (90%) create mode 100644 test/indexing/CMakeLists.txt create mode 100644 test/indexing/test_getitem.py create mode 100644 test/indexing/test_setitem.py diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index 4f70ce98249..cb11759af33 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -1120,6 +1120,9 @@ static PyObject* tensor__getitem_index_not_tensor(TensorObject* self, eager_gil_scoped_release guard; out = strided_slice_ad_func( self->tensor, slice_axes, slice_starts, slice_ends, slice_strides); + if (!decrease_axis_tmp.empty()) { + out = squeeze_ad_func(out, decrease_axis_tmp); + } } else { PADDLE_THROW(platform::errors::InvalidArgument( "Slice is only support slice and strided_slice, but we got %s which " diff --git a/python/paddle/fluid/dygraph/tensor_patch_methods.py b/python/paddle/fluid/dygraph/tensor_patch_methods.py index 21ff77f7cfe..a294bc6de4e 100644 --- a/python/paddle/fluid/dygraph/tensor_patch_methods.py +++ b/python/paddle/fluid/dygraph/tensor_patch_methods.py @@ -26,7 +26,8 @@ from ..
import unique_name from ..framework import ( Variable, Parameter, - _getitem_impl_, + _getitem_static, + _setitem_static, _setitem_impl_, EagerParamBase, in_dygraph_mode, @@ -726,47 +727,34 @@ def monkey_patch_tensor(): return True return False - def __getitem__(self, item): - def is_list_tuple(index, contain_type): - def _is_list_tuple(item): - if isinstance(item, (tuple, list)): - for s in item: - if not _is_list_tuple(s): - return False - else: - if type(item) != contain_type: - return False + def contain_tensor_or_list(item): + if not isinstance(item, tuple): + item = (item,) + + for slice_item in item: + if isinstance(slice_item, (list, np.ndarray, Variable)): return True + elif isinstance(slice_item, slice): + if ( + isinstance(slice_item.start, Variable) + or isinstance(slice_item.stop, Variable) + or isinstance(slice_item.step, Variable) + ): + return True - if not isinstance(index, (tuple, list)): - return False - for s in index: - if not _is_list_tuple(s): - return False - return True + return False - if contain_tensor(item) or is_list_tuple(item, int): + def __getitem__(self, item): + if contain_tensor_or_list(item): # 1. Call _getitem_impl_ when item contains tensor. # Why not call a c++ function ? Because item can't be parsed when it contains tensor. - return _getitem_impl_(self, item) + return _getitem_static(self, item) else: # 2. Call c++ func getitem_index_not_tensor to speedup. return self._getitem_index_not_tensor(item) def __setitem__(self, item, value): - def contain_tensor_or_list(item): - if not isinstance(item, tuple): - item = [item] - - for slice_item in item: - if isinstance(slice_item, list): - return True - elif isinstance(slice_item, Variable): - return True - - return False - def is_combine_index(item): var_type = None item_type = None @@ -788,10 +776,13 @@ def monkey_patch_tensor(): return False - if contain_tensor_or_list(item) and not is_combine_index(item): + if contain_tensor_or_list(item): + if core.is_compiled_with_xpu() and not is_combine_index(item): + # (NOTE): Currently, there is no index_put_xpu kernel. + return _setitem_impl_(self, item, value) # To reuse code with static graph, - # Call _setitem_impl_ when item contains tensor or list. - return _setitem_impl_(self, item, value) + # Call _setitem_static when item contains tensor or list. + return _setitem_static(self, item, value) else: return self.__setitem_eager_tensor__(item, value) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index efb038669e8..29389ea61af 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -37,7 +37,7 @@ from . import unique_name import paddle.version as fluid_version import warnings import functools -from .variable_index import _getitem_impl_, _setitem_impl_ +from .variable_index import _getitem_static, _setitem_static, _setitem_impl_ import threading __all__ = [ @@ -2294,13 +2294,16 @@ class Variable(metaclass=VariableMetaClass): raise IndexError("Valid index accept int or slice or tuple") def __getitem__(self, item): - return _getitem_impl_(self, item) + return _getitem_static(self, item) def __setitem__(self, item, value): from .dygraph.base import in_declarative_mode if in_declarative_mode(): - return _setitem_impl_(self, item, value) + if is_compiled_with_xpu(): + # (NOTE): Currently, there is no index_put_xpu kernel. 
+ return _setitem_impl_(self, item, value) + return _setitem_static(self, item, value) else: raise RuntimeError( "In static mode, the __setitem__ (looks like: x[indices] = values) should not be used. Please use x = paddle.static.setitem(x, indices, values)" diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index 2bea3bbdb9d..acf30532fbe 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -19,6 +19,7 @@ from . import core import paddle import warnings + MAX_INTEGER = 2**31 - 1 @@ -294,7 +295,7 @@ def is_integer_or_scalar_tensor(ele): "1-D Tensor will be treat as advanced indexing in future version. Currently, 1-D Tensor means a scalar, not vector, and please modify it to 0-D Tensor. If advanced indexing is needed, please use `export FLAGS_set_to_1d=False` to set the flag." ) return True - if len(ele.shape) == 0: + if len(ele.shape) == 0 and ele.dtype != paddle.bool: return True return False @@ -361,273 +362,6 @@ def get_value_for_bool_tensor(var, item): ) -def _getitem_impl_(var, item): - """ - Slice the variable. - - Args: - item(int/slice/tuple) : the index. - - Returns: - Sliced variable - """ - from .framework import default_main_program, Variable - from .dygraph.base import in_declarative_mode - - if in_declarative_mode() and hasattr(var, "is_view_var"): - var.is_view_var = True - - if isinstance(item, list): - if not is_one_dim_list(item, int): - item = tuple(item) - - if not isinstance(item, tuple): - item = (item,) - - decrease_axes = [] - axes = [] - starts = [] - ends = [] - steps = [] - reverse_axes = [] - - use_strided_slice = False - item = replace_ndarray(item) - item = replace_ellipsis(var, item) - item, none_axes = replace_none(item) - slice_info = SliceInfo() - is_tensor_array = ( - hasattr(var, "desc") - and var.desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY - ) - - for dim, slice_item in enumerate(item): - if is_integer_or_scalar_tensor(slice_item) and not is_bool_tensor( - slice_item - ): - if ( - not is_tensor_array - and isinstance(slice_item, int) - and var.shape[dim] is not None - and var.shape[dim] >= 0 - and slice_item >= var.shape[dim] - ): - # For python, if users write a, b = var, the __getitem__ - # method will iterate through 0, 1, 2 ... until __getitem__ - # throws an IndexError, then stop. The var[0], var[1] will - # be given to a, b respectively. If more values are given, - # the unpack size would cause error. 
- # We raises IndexError here to support grammar like `a, b = var` - raise IndexError( - "slice_item %d at dim %d should be >= 0 and < var.shape[%d]: %d" - % (slice_item, dim, dim, var.shape[dim]) - ) - decrease_axes.append(dim) - start = slice_item - step = 1 - end = slice_item + 1 if slice_item != -1 else MAX_INTEGER - - elif isinstance(slice_item, slice): - start = slice_item.start - end = slice_item.stop - step = slice_item.step - - if start is None and end is None and step is None: - continue - - step = 1 if step is None else step - - if start is None: - start = 0 if step > 0 else MAX_INTEGER - if end is None: - if ( - paddle.in_dynamic_mode() or not is_tensor_array - ) and var.shape[dim] != -1: - end = var.shape[dim] if step > 0 else -1 - else: - end = MAX_INTEGER if step > 0 else -1 - - elif isinstance(slice_item, list): - all_bool = True - - if is_list_tuple(slice_item, int): - slice_info.update(slice_item) - continue - - for i in slice_item: - if type(i) is int: - all_bool = False - elif not isinstance(i, bool): - raise TypeError("Only support int or bool in index list.") - - if len(item) != 1: - raise IndexError( - "When index contains a list, its length must be 1, but received {}.".format( - len(item) - ) - ) - new_slice_item = [] - if all_bool: - if len(slice_item) != var.shape[0]: - raise IndexError( - "The dimension of bool index doesn't match indexed array along " - "dimension 0, the target dimension is {}, but received {}.".format( - var.shape[0], len(slice_item) - ) - ) - for idx, ele in enumerate(slice_item): - if ele is True: - new_slice_item.append(idx) - slice_item = new_slice_item - else: - for idx, ele in enumerate(slice_item): - if type(ele) is int: - new_slice_item.append(ele) - elif ele is True: - new_slice_item.append(1) - else: - new_slice_item.append(0) - slice_item = new_slice_item - - from ..tensor import index_select - - idx = paddle.assign(np.array(slice_item).astype("int32")) - return index_select(var, index=idx, axis=0) - - elif isinstance(slice_item, (Variable, core.eager.Tensor)): - if len(item) == 1: - from ..tensor import index_select - - if slice_item.dtype == paddle.bool: - if in_declarative_mode(): - tmp = get_value_for_bool_tensor(var, slice_item) - if hasattr(tmp, "is_view_var"): - tmp.is_view_var = True - return tmp - else: - return get_value_for_bool_tensor(var, slice_item) - else: - if len(slice_item.shape) == 1: - return index_select(var, index=slice_item, axis=0) - else: - slice_info.update(slice_item) - continue - else: - slice_info.update(slice_item) - continue - - else: - raise IndexError( - "Valid index accept int or slice or ellipsis or list, but received {}.".format( - slice_item - ) - ) - - axes.append(dim) - starts.append(start) - ends.append(end) - steps.append(step) - use_strided_slice = True if step != 1 else use_strided_slice - - if slice_info.indexes: - if len(slice_info.indexes) != len(item): - raise IndexError( - "Valid index accept int or slice or ellipsis or list, but received {}.".format( - item - ) - ) - if in_declarative_mode(): - tmp = slice_info.get_item(var) - if hasattr(tmp, "is_view_var"): - tmp.is_view_var = True - return tmp - else: - return slice_info.get_item(var) - - inputs = {'Input': [var]} - attrs = { - 'axes': axes, - 'starts': [], - 'ends': [], - 'decrease_axis': decrease_axes, - } - if use_strided_slice: - attrs['strides'] = [] - - infer_flags = [1] * len(axes) - deal_attrs(attrs, starts, "starts", "StartsTensorList", inputs, infer_flags) - deal_attrs(attrs, ends, "ends", "EndsTensorList", inputs, 
infer_flags) - deal_attrs( - attrs, steps, "strides", "StridesTensorList", inputs, infer_flags - ) - attrs['infer_flags'] = infer_flags - - out = var - if len(axes) > 0: - op_type = "strided_slice" if use_strided_slice else "slice" - if paddle.in_dynamic_mode() and op_type == "slice": - if "StartsTensorList" in inputs.keys(): - st = inputs['StartsTensorList'] - else: - st = attrs['starts'] - if "EndsTensorList" in inputs.keys(): - end = inputs['EndsTensorList'] - else: - end = attrs['ends'] - out = paddle._C_ops.slice( - var, axes, st, end, attrs['infer_flags'], attrs['decrease_axis'] - ) - else: - target_block = default_main_program().current_block() - - slice_out_var = target_block.create_var( - name=unique_name.generate_with_ignorable_key( - var.name + "_" + op_type - ), - dtype=var.dtype, - ) - target_block.append_op( - type=op_type, - inputs=inputs, - outputs={'Out': [slice_out_var]}, - attrs=attrs, - ) - out = slice_out_var - - if len(reverse_axes) > 0: - from .layers.tensor import reverse - - out = reverse(out, axis=reverse_axes) - - # NOTE(zoooo0820): When all axes are decreased, the output will be 1-D - # with FLAGS_set_to_1d=True. In this case, one `None` should be pop out, - # otherwise the output shape will be not correct. - set_to_1d = paddle.get_flags('FLAGS_set_to_1d')['FLAGS_set_to_1d'] - if set_to_1d and len(decrease_axes) == len(var.shape): - warnings.warn( - "Warning: In Tensor '__getitem__', if the number of scalar elements in the index is equal to the rank of the Tensor, the output should be 0-D. In order to be consistent with the behavior of previous versions, it will be processed to 1-D. But it is not correct and will be removed in release 2.6. If 1-D is still wanted, please modify the index element from scalar to slice (e.g. 'x[i]' => 'x[i:i+1]')." - ) - none_axes = none_axes[1:] - - if len(none_axes) > 0: - # Deal with cases that decrease_axes is not empty - # For example: - # # x.shape: (2,3,4) - # out = x[0, 0:2, None] # out.shape : (2, 1, 4) - for idx, axis in enumerate(none_axes): - l = len([i for i in decrease_axes if i < axis]) - new_axis = axis - l - none_axes[idx] = new_axis - - from ..tensor import unsqueeze - - out = unsqueeze(out, axis=none_axes) - - if in_declarative_mode() and hasattr(out, "is_view_var"): - out.is_view_var = True - return out - - def _setitem_for_tensor_array(var, item, value): """branches for tensor array setitem operation. A item can be a: @@ -662,8 +396,8 @@ def _setitem_for_tensor_array(var, item, value): def _setitem_impl_(var, item, value): - from .framework import default_main_program, Variable from paddle.fluid import core + from .framework import default_main_program, Variable if var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY: return _setitem_for_tensor_array(var, item, value) @@ -898,3 +632,542 @@ def set_value_for_bool_tensor(var, item, value): ) return var + + +def deal_advanced_index(ori_tensor, indices, is_for_setitem): + """ + Transpose the original Tensor and move the advanced indices to the front. + + Returns: + transed_tensor (Tensor): transposed tensor, corresponding to the advanced indices + transed_index (List): advanced indices transposed to the front + trans_back_dim (List): order of axes to transpose back to the original order. Only used in __setitem__. + pos_of_new_dim (int): axis of the new dim in the result. Only used in __getitem__. + rank_of_new_dim (int): rank of the new dim in the result. Only used in __getitem__.
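+ + For example (an illustrative case): for a Tensor x with shape [3, 4, 5] indexed as x[[0, 1], :, [2, 3]], the advanced indices sit on axes 0 and 2, so the transpose order is [0, 2, 1] and transed_tensor has shape [3, 5, 4].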
+ """ + transed_dim = [] + transed_index = [] + + # These flags indicate whether the result obtained by gather_nd requires a second transpose. + # Only used in __getitem__. + pos_of_new_dim = MAX_INTEGER + rank_of_new_dim = 1 + + for i, indice in enumerate(indices): + if indice is not None: + if not is_for_setitem: + if i == 0: + # case 1: advanced indices at axis 0, the new dim will be placed first. + pos_of_new_dim = 0 + if i > 0 and len(transed_dim) > 0 and transed_dim[-1] != i - 1: + # case 2: the advanced indices are not adjacent, so the new dim will be placed first. + pos_of_new_dim = 0 + else: + pos_of_new_dim = min(pos_of_new_dim, i) + rank_of_new_dim = max(rank_of_new_dim, indice[1].ndim) + transed_dim.append(i) + transed_index.append(indice[1]) + for i in range(ori_tensor.ndim): + if indices[i] is None: + transed_dim.append(i) + transed_tensor = ori_tensor.transpose(transed_dim) + + trans_back_dim = np.argsort(transed_dim).tolist() if is_for_setitem else [] + + return ( + transed_tensor, + transed_index, + trans_back_dim, + pos_of_new_dim, + rank_of_new_dim, + ) + + +def parse_index(x, indices): + advanced_index = [None] * 2 * len(x.shape) # content is (dim, index) + # for set_value / slice / strided_slice OP + decrease_axes = [] + axes = [] + starts = [] + ends = [] + steps = [] + use_strided_slice = False + has_advanced_index = False + + if isinstance(indices, list) and not is_one_dim_list(indices, int): + indices = tuple(indices) + + if not isinstance(indices, tuple): + indices = (indices,) + + indices = replace_ndarray(indices) + indices = replace_ellipsis(x, indices) + indices, none_axes = replace_none(indices) + + is_tensor_array = ( + hasattr(x, "desc") + and x.desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY + ) + + estimated_dim = 0 + for dim, slice_item in enumerate(indices): + start, end, step = None, None, None + if is_integer_or_scalar_tensor(slice_item): + if ( + not is_tensor_array + and isinstance(slice_item, int) + and x.shape[dim] is not None + and x.shape[dim] >= 0 + and slice_item >= x.shape[dim] + ): + # For python, if users write a, b = var, the __getitem__ + # method will iterate through 0, 1, 2 ... until __getitem__ + # throws an IndexError, then stop. The var[0], var[1] will + # be given to a, b respectively. If more values are given, + # the unpack size would cause error. + # We raise IndexError here to support grammar like `a, b = var` + raise IndexError( + "slice_item %d at dim %d should be >= 0 and < x.shape[%d]: %d" + % (slice_item, dim, dim, x.shape[dim]) + ) + # do not calculate the result here, to reduce the number of slice OP calls.
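+ # An integer index i on axis `dim` is recorded as the slice i:i+1 on that axis, and `dim` is added to decrease_axes so the axis is squeezed from the result afterwards.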
+ decrease_axes.append(dim) + start = slice_item + step = 1 + end = slice_item + 1 if slice_item != -1 else MAX_INTEGER + elif isinstance(slice_item, bool): + # a single bool is advanced indexing + none_axes.append(dim) + estimated_dim += 1 + advanced_index[estimated_dim] = ( + estimated_dim, + paddle.to_tensor(slice_item), + ) + has_advanced_index = True + elif isinstance(slice_item, slice): + start = slice_item.start + end = slice_item.stop + step = slice_item.step + estimated_dim += 1 + + if start is None and end is None and step is None: + continue + + step = 1 if step is None else step + if start is None: + start = 0 if step > 0 else MAX_INTEGER + if end is None: + end = MAX_INTEGER if step > 0 else -1 + + elif isinstance(slice_item, (list, tuple)): + advanced_index[estimated_dim] = ( + estimated_dim, + paddle.to_tensor(slice_item), + ) + + if ( + advanced_index[estimated_dim][1].dtype == paddle.bool + and len(slice_item) != x.shape[dim] + ): + raise IndexError( + "The shape of boolean index {} did not match indexed tensor {} along axis {}".format( + len(slice_item), x.shape[dim], dim + ) + ) + + has_advanced_index = True + estimated_dim += 1 + + elif isinstance(slice_item, paddle.fluid.Variable): + # In this case, the Variable is not a 0-D Tensor and will be treated as advanced indexing. + if slice_item.dtype == paddle.bool: + if slice_item.ndim == 0: + # 0-D bool Tensor, same as a single Python bool. + none_axes.append(dim) + + elif slice_item.shape[0] != x.shape[dim]: + raise IndexError( + "The shape of boolean index {} did not match indexed tensor {} along axis {}".format( + slice_item.shape[0], x.shape[dim], dim + ) + ) + advanced_index[estimated_dim] = (estimated_dim, slice_item) + has_advanced_index = True + estimated_dim += 1 + + else: + raise IndexError( + "Valid index accept int / bool / slice / ellipsis / list / Tuple / Ndarray / Tensor, but received {}.".format( + slice_item + ) + ) + if not (start is None or end is None or step is None): + starts.append(start) + ends.append(end) + steps.append(step) + axes.append(dim) + use_strided_slice = ( + True + if (isinstance(step, paddle.fluid.Variable) or step != 1) + else use_strided_slice + ) + return ( + starts, + ends, + steps, + axes, + none_axes, + decrease_axes, + advanced_index, + has_advanced_index, + use_strided_slice, + ) + + +def _setitem_static(x, indices, values): + """ + In dynamic mode, this function modifies the values in the input tensor in place and returns the same Tensor as the input. + In static mode, it returns a new Tensor with the assigned values. + + Args: + x(Tensor): Tensor whose values are to be set. + indices(int|slice|None|Tensor|List|Tuple...): Indices, used to indicate the position of the elements to be assigned. + values(Tensor|Number|Ndarray): values to be assigned to x.
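+ + For example, _setitem_static(x, ([0, 1], slice(None, None, None), [1, 2]), 10.0) is the static-graph equivalent of the dygraph statement x[[0, 1], :, [1, 2]] = 10.0.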
+ """ + from .framework import default_main_program, Variable + + if x.type == paddle.fluid.core.VarDesc.VarType.LOD_TENSOR_ARRAY: + return _setitem_for_tensor_array(x, indices, values) + + # step1: parse the indices and record them + ( + starts, + ends, + steps, + axes, + none_axes, + decrease_axes, + advanced_index, + has_advanced_index, + use_strided_slice, + ) = parse_index(x, indices) + + inputs = {'Input': x} + attrs = { + 'axes': axes, + 'starts': starts, + 'ends': ends, + 'steps': steps, + 'decrease_axes': decrease_axes, + 'none_axes': none_axes, + } + if paddle.utils._contain_var(starts): + inputs['StartsTensorList'] = paddle.utils._convert_to_tensor_list( + starts + ) + del attrs['starts'] + if paddle.utils._contain_var(ends): + inputs['EndsTensorList'] = paddle.utils._convert_to_tensor_list(ends) + del attrs['ends'] + if paddle.utils._contain_var(steps): + inputs['StepsTensorList'] = paddle.utils._convert_to_tensor_list(steps) + del attrs['steps'] + + if not has_advanced_index: + # step2. Parse values + dtype = x.dtype + attrs['dtype'] = dtype + + from .data_feeder import convert_dtype + + if isinstance(values, (bool, int, float, complex)): + values = np.array([values]).astype(convert_dtype(dtype)) + + if isinstance(values, np.ndarray): + shape = list(values.shape) + values = values.ravel().tolist() + attrs["values"] = values + attrs["shape"] = shape + + elif isinstance(values, Variable): + inputs["ValueTensor"] = values + else: + raise TypeError( + "Only support to assign an integer, float, numpy.ndarray or " + "paddle.Tensor to a paddle.Tensor, but received {}".format( + type(values) + ) + ) + + # step3.1: Only basic indexing, use OP set_value to set the value. + if paddle.in_dynamic_mode(): + x._bump_inplace_version() + output = x + else: + helper = paddle.fluid.layer_helper.LayerHelper( + 'set_value', **locals() + ) + output = helper.create_variable_for_type_inference(dtype=x.dtype) + cur_block = default_main_program().current_block() + cur_block.append_op( + type="set_value", + inputs=inputs, + outputs={'Out': output}, + attrs=attrs, + inplace_map={"Input": "Out"}, + ) + + if not paddle.in_dynamic_mode(): + # map var to the new output + paddle.jit.api.ProgramTranslator.get_instance()._params_map.add( + cur_block.program, x.desc.id(), output + ) + return output + else: + # step3.2: Case with advanced indexing. + # 1. get the __getitem__ result of the basic indexing; + # 2. transpose the original tensor so that the axes with advanced indexing come first; + # 3. assign values to the sliced result by the index_put OP; + # 4. transpose back and assign the result to the original tensor by the set_value OP. + + sub_tensor = get_tensor_with_basic_indexing( + x, + axes, + starts, + ends, + steps, + decrease_axes, + none_axes, + use_strided_slice, + ) + ( + transed_sub_tensor, + adjusted_advanced_index, + transback_dim, + _, + _, + ) = deal_advanced_index(sub_tensor, advanced_index, True) + if not isinstance(values, Variable): + values = paddle.assign(values).astype(transed_sub_tensor.dtype) + + # NOTE(zoooo0820): currently, basic indexing in __getitem__ returns a new Tensor in both dynamic and static mode. + # Once stride support is ready and basic indexing returns a view of the Tensor in dynamic mode, this code should + # be changed for dynamic mode.
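+ # index_put_ modifies transed_sub_tensor in place on the dygraph path, while index_put returns a new Tensor on the static path.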
+ if paddle.in_dynamic_mode(): + transed_sub_tensor.index_put_(adjusted_advanced_index, values) + else: + transed_sub_tensor = transed_sub_tensor.index_put( + adjusted_advanced_index, values + ) + + transback_sub_tensor = transed_sub_tensor.transpose(transback_dim) + + inputs["ValueTensor"] = transback_sub_tensor + if paddle.in_dynamic_mode(): + x._bump_inplace_version() + output = x + else: + helper = paddle.fluid.layer_helper.LayerHelper( + 'set_value', **locals() + ) + output = helper.create_variable_for_type_inference(dtype=x.dtype) + cur_block = default_main_program().current_block() + cur_block.append_op( + type="set_value", + inputs=inputs, + outputs={'Out': output}, + attrs=attrs, + inplace_map={"Input": "Out"}, + ) + if not paddle.in_dynamic_mode(): + # map var to the new output + paddle.jit.api.ProgramTranslator.get_instance()._params_map.add( + cur_block.program, x.desc.id(), output + ) + return output + + +def get_tensor_with_basic_indexing( + x, axes, starts, ends, steps, decrease_axes, none_axes, use_strided_slice +): + from .dygraph.base import in_declarative_mode + + if in_declarative_mode() and hasattr(x, "is_view_var"): + x.is_view_var = True + + if len(axes) == 0: + out = x + else: + op_type = "strided_slice" if use_strided_slice else "slice" + inputs = {'Input': [x]} + attrs = { + 'axes': axes, + 'starts': [], + 'ends': [], + 'decrease_axis': decrease_axes, + } + if use_strided_slice: + attrs['strides'] = [] + infer_flags = [1] * len(axes) + deal_attrs( + attrs, starts, "starts", "StartsTensorList", inputs, infer_flags + ) + deal_attrs(attrs, ends, "ends", "EndsTensorList", inputs, infer_flags) + deal_attrs( + attrs, steps, "strides", "StridesTensorList", inputs, infer_flags + ) + attrs['infer_flags'] = infer_flags + + if paddle.in_dynamic_mode(): + if "StartsTensorList" in inputs.keys(): + st = inputs['StartsTensorList'] + else: + st = attrs['starts'] + if "EndsTensorList" in inputs.keys(): + end = inputs['EndsTensorList'] + else: + end = attrs['ends'] + if "StridesTensorList" in inputs.keys(): + stride = inputs['StridesTensorList'] + else: + stride = attrs['strides'] + if use_strided_slice: + out = paddle._C_ops.strided_slice(x, axes, st, end, stride) + if len(decrease_axes) > 0: + out = paddle._C_ops.squeeze(out, decrease_axes) + else: + out = paddle._C_ops.slice( + x, + axes, + st, + end, + attrs['infer_flags'], + attrs['decrease_axis'], + ) + else: + from .framework import default_main_program + + target_block = default_main_program().current_block() + + slice_out_var = target_block.create_var( + name=unique_name.generate_with_ignorable_key( + x.name + "_" + op_type + ), + dtype=x.dtype, + ) + target_block.append_op( + type=op_type, + inputs=inputs, + outputs={'Out': [slice_out_var]}, + attrs=attrs, + ) + out = slice_out_var + # NOTE(zoooo0820): When all axes are decreased, the output will be 1-D + # with FLAGS_set_to_1d=True. In this case, one `None` should be popped out, + # otherwise the output shape will not be correct. + set_to_1d = paddle.get_flags('FLAGS_set_to_1d')['FLAGS_set_to_1d'] + if set_to_1d and len(decrease_axes) == len(x.shape): + warnings.warn( + "Warning: In Tensor '__getitem__', if the number of scalar elements in the index is equal to the rank of the Tensor, the output should be 0-D. In order to be consistent with the behavior of previous versions, it will be processed to 1-D. But it is not correct and will be removed in release 2.6. If 1-D is still wanted, please modify the index element from scalar to slice (e.g. 'x[i]' => 'x[i:i+1]')."
+ ) + none_axes = none_axes[1:] + + if len(none_axes) > 0: + # Deal with cases that decrease_axes is not empty + # For example: + # # x.shape: (2,3,4) + # out = x[0, 0:2, None] # out.shape : (2, 1, 4) + for idx, axis in enumerate(none_axes): + l = len([i for i in decrease_axes if i < axis]) + new_axis = axis - l + none_axes[idx] = new_axis + + out = paddle.unsqueeze(out, axis=none_axes) + + if in_declarative_mode() and hasattr(out, "is_view_var"): + out.is_view_var = True + return out + + +def _getitem_static(x, indices): + """ + Args: + x(Tensor): Tensor to be indexing. + indices(int|slice|None|Tensor|List|Tuple...): Indices, used to indicate the position of the element to be fetched. + """ + # step1: parsing the index and recording them + ( + starts, + ends, + steps, + axes, + none_axes, + decrease_axes, + advanced_index, + has_advanced_index, + use_strided_slice, + ) = parse_index(x, indices) + + # step2: Dealing with basic indexing + out = get_tensor_with_basic_indexing( + x, + axes, + starts, + ends, + steps, + decrease_axes, + none_axes, + use_strided_slice, + ) + + # step3: Dealing with advanced indexing + if has_advanced_index: + ( + transed_tensor, + adjusted_advanced_index, + _, + pos_of_new_dim, + rank_of_new_dim, + ) = deal_advanced_index(out, advanced_index, False) + + # TODO(zooooo0820): Replacing gather_nd to another advanded OP for handling of mixed indexes more efficiently + if ( + len(adjusted_advanced_index) == 1 + and adjusted_advanced_index[0].dtype == paddle.bool + ): + # Note: now slice not support 0-size Tensor, so only one bool tensor can return empty 0-size. + out = get_value_for_bool_tensor( + transed_tensor, adjusted_advanced_index[0] + ) + else: + adjusted_advanced_index = parse_bool_and_broadcast_indices( + adjusted_advanced_index + ) + advanced_index_tensor = paddle.stack( + adjusted_advanced_index, axis=-1 + ) + out = paddle.gather_nd(transed_tensor, advanced_index_tensor) + + if pos_of_new_dim != 0: + perm = ( + list(range(pos_of_new_dim, pos_of_new_dim + rank_of_new_dim)) + + list(range(0, pos_of_new_dim)) + + list(range(pos_of_new_dim + rank_of_new_dim, out.ndim)) + ) + out = out.transpose(perm) + + return out + + +def parse_bool_and_broadcast_indices(indices): + # deal with multiple Tensors and translating bool tensor to int tensor. + # In static mode, bool-tensor cannot be broadcasted since its corressponding int tensor's shape cannot be infered. + for i, indice in enumerate(indices): + if indice.dtype == paddle.bool: + indices[i] = paddle.nonzero(indice)[:, 0] + if len(indices) > 1: + indices = paddle.broadcast_tensors(indices) + return indices diff --git a/python/paddle/static/input.py b/python/paddle/static/input.py index ab8f80c8879..4f856227bda 100644 --- a/python/paddle/static/input.py +++ b/python/paddle/static/input.py @@ -18,7 +18,7 @@ from paddle.fluid.data_feeder import check_type from paddle.fluid.framework import convert_np_dtype_to_dtype_, static_only from paddle.fluid.layer_helper import LayerHelper -from ..fluid.variable_index import _setitem_impl_ +from ..fluid.variable_index import _setitem_impl_, _setitem_static __all__ = [] @@ -367,5 +367,8 @@ def setitem(x, index, value): (1) a[Tensor([10,10])]=v -> setitem(a, (Tensor([10,10]),), v) (2) a[1] = v -> setitem(a, (1,), v) """ - - return _setitem_impl_(x, index, value) + if core.is_compiled_with_xpu(): + # (NOTE): Currently, there is no index_put_xpu kernel. 
+ return _setitem_impl_(x, index, value) + else: + return _setitem_static(x, index, value) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 20de0a13a6f..19bd7f8d83e 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -127,6 +127,7 @@ if(WITH_TESTING) add_subdirectory(ipu) endif() add_subdirectory(ir) + add_subdirectory(indexing) add_subdirectory(legacy_test) if(WITH_MKLDNN) add_subdirectory(mkldnn) diff --git a/test/dygraph_to_static/test_setitem.py b/test/dygraph_to_static/test_jit_setitem.py similarity index 90% rename from test/dygraph_to_static/test_setitem.py rename to test/dygraph_to_static/test_jit_setitem.py index 93b8c5d7936..374d0569c59 100644 --- a/test/dygraph_to_static/test_setitem.py +++ b/test/dygraph_to_static/test_jit_setitem.py @@ -178,5 +178,24 @@ class TestCase11(TestSetItemBase): return y, x_grad, value_grad +class TestCase12(TestSetItemBase): + # Test combined indexing + def init_func(self): + def foo(x, value): + y = x + 1 + y[[0, 1], 1, :2] = value + return y + + return foo + + def run_dygrah(self, func): + x = self.init_data() + value = paddle.ones((32,)) + value.stop_gradient = False + y = func(x, value) + x_grad, value_grad = paddle.grad(y, [x, value]) + return y, x_grad, value_grad + + if __name__ == '__main__': unittest.main() diff --git a/test/indexing/CMakeLists.txt b/test/indexing/CMakeLists.txt new file mode 100644 index 00000000000..95739040ef4 --- /dev/null +++ b/test/indexing/CMakeLists.txt @@ -0,0 +1,9 @@ +file( + GLOB TEST_OPS + RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" + "test_*.py") +string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") + +foreach(TEST_OP ${TEST_OPS}) + py_test_modules(${TEST_OP} MODULES ${TEST_OP}) +endforeach() diff --git a/test/indexing/test_getitem.py b/test/indexing/test_getitem.py new file mode 100644 index 00000000000..bb18a0b7723 --- /dev/null +++ b/test/indexing/test_getitem.py @@ -0,0 +1,328 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
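+# The cases below exercise the new combined (basic + advanced) indexing path in both dygraph and static mode.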
+ +import unittest + +import numpy as np + +import paddle +from paddle.fluid.variable_index import _getitem_static + + +class TestGetitemInDygraph(unittest.TestCase): + def setUp(self): + paddle.disable_static() + + def test_combined_index_1(self): + # int tensor + slice (without decreasing axes) + np_data = np.random.randn(3, 4, 5, 6) + np_res = np_data[[0, 1], :, [1, 2]] + x = paddle.to_tensor(np_data) + y = x[[0, 1], :, [1, 2]] + + np.testing.assert_allclose(y.numpy(), np_res) + + def test_combined_index_2(self): + # int tensor + slice (with decreasing axes) + np_data = np.random.randn(3, 4, 5, 6) + x = paddle.to_tensor(np_data) + + np_res = np_data[:, 1, [1, 2], 0] + y = x[:, 1, [1, 2], 0] + + np.testing.assert_allclose(y.numpy(), np_res) + + def test_combined_index_3(self): + # multiple int tensors, with one int tensor at first axis + np_data = np.random.randn(3, 4, 5, 6, 7) + np_res = np_data[[1, 0], :, [1, 4], 1:5:2, 4] + + x = paddle.to_tensor(np_data) + y = x[[1, 0], :, [1, 4], 1:5:2, 4] + + np.testing.assert_allclose(y.numpy(), np_res) + + def test_combined_index_4(self): + # multiple not adjacent int tensors, with no int tensor at first axis + np_data = np.random.randn(3, 4, 5, 6, 7) + np_res = np_data[:, [1, 0], 0:4:2, [2, 3], 4] + x = paddle.to_tensor(np_data) + y = x[:, [1, 0], 0:4:2, [2, 3], 4] + + np.testing.assert_allclose(y.numpy(), np_res) + + def test_combined_index_5(self): + # multiple adjacent int tensors, with no int tensor at first axis + np_data = np.random.randn(3, 4, 5, 6, 7) + np_res = np_data[::2, [1, 0], [2, 3], 0:4:2] + x = paddle.to_tensor(np_data) + y = x[::2, [1, 0], [2, 3], 0:4:2] + + np.testing.assert_allclose(y.numpy(), np_res) + + def test_combined_index_6(self): + # multiple adjacent and not adjacent int tensors, with no int tensor at first axis + np_data = np.random.randn(3, 4, 5, 6, 7) + np_res = np_data[::2, [1, 0], [2, 3], 0:4:2, [4, 6]] + x = paddle.to_tensor(np_data) + y = x[::2, [1, 0], [2, 3], 0:4:2, [4, 6]] + + np.testing.assert_allclose(y.numpy(), np_res) + + def test_combined_index_7(self): + # multiple adjacent and not adjacent int tensors (rank > 1d), with no int tensor at first axis + np_data = np.random.randn(3, 4, 5, 6, 7) + np_res = np_data[::2, [[1, 0]], [[2, 3]], 0:4:2, [[4, 6]]] + x = paddle.to_tensor(np_data) + y = x[::2, [[1, 0]], [[2, 3]], 0:4:2, [[4, 6]]] + + np.testing.assert_allclose(y.numpy(), np_res) + + def test_combined_index_8(self): + # multiple adjacent and not adjacent int tensors (rank > 1d), with int tensor at first axis + np_data = np.random.randn(3, 4, 5, 6, 7) + np_res = np_data[ + [[1, 0], [0, 1]], [[2, 3], [1, 0]], 0:4:2, [[3, 5], [4, 2]] + ] + x = paddle.to_tensor(np_data) + y = x[[[1, 0], [0, 1]], [[2, 3], [1, 0]], 0:4:2, [[3, 5], [4, 2]]] + + np.testing.assert_allclose(y.numpy(), np_res) + + def test_combined_index_9(self): + # multiple int tensors, with broadcast. 
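+ # index shapes (1, 2), (2,) and (2, 2) broadcast to (2, 2)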
+ np_data = np.random.randn(3, 4, 5, 6, 7) + np_res = np_data[[[1, 0]], [1, 0], 0:4:2, [[3, 5], [4, 2]]] + x = paddle.to_tensor(np_data) + y = x[[[1, 0]], [1, 0], 0:4:2, [[3, 5], [4, 2]]] + + np.testing.assert_allclose(y.numpy(), np_res) + + def test_combined_index_10(self): + # only one bool tensor with basic-index + np_data = np.random.randn(3, 4, 5, 6) + np_res = np_data[:, [True, False, True, False], 4] + + x = paddle.to_tensor(np_data) + y = x[:, [True, False, True, False], 4] + + np.testing.assert_allclose(y.numpy(), np_res) + + def test_combined_index_11(self): + # only one bool tensor with all False + np_data = np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6)) + np_res = np_data[:, [False, False, False, False], 4] + + x = paddle.to_tensor(np_data) + y = x[:, [False, False, False, False], 4] + + np.testing.assert_allclose(y.numpy(), np_res) + + +class TestGetitemInStatic(unittest.TestCase): + def setUp(self): + paddle.enable_static() + self.exe = paddle.static.Executor() + + def test_combined_index_1(self): + # int tensor + slice (without decreasing axes) + np_data = np.random.randn(3, 4, 5, 6) + np_res = np_data[[0, 1], :, [1, 2]] + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.to_tensor(np_data) + y = _getitem_static(x, ([0, 1], slice(None, None, None), [1, 2])) + res = self.exe.run(fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_res) + + def test_combined_index_2(self): + # int tensor + slice (with decreasing axes) + np_data = np.random.randn(3, 4, 5, 6) + np_res = np_data[:, 1, [1, 2], 0] + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.to_tensor(np_data) + y = _getitem_static(x, (slice(None, None, None), 1, [1, 2], 0)) + res = self.exe.run(fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_res) + + def test_combined_index_3(self): + # multiple int tensors, with one int tensor at first axis + np_data = np.random.randn(3, 4, 5, 6, 7) + np_res = np_data[[1, 0], :, [1, 4], 1:5:2, 4] + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.to_tensor(np_data) + y = _getitem_static( + x, ([1, 0], slice(None, None, None), [1, 4], slice(1, 5, 2), 4) + ) + res = self.exe.run(fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_res) + + def test_combined_index_4(self): + # multiple not adjacent int tensors, with no int tensor at first axis + np_data = np.random.randn(3, 4, 5, 6, 7) + np_res = np_data[:, [1, 0], 0:4:2, [2, 3], 4] + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.to_tensor(np_data) + y = _getitem_static( + x, (slice(None, None, None), [1, 0], slice(0, 4, 2), [2, 3], 4) + ) + res = self.exe.run(fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_res) + + def test_combined_index_5(self): + # multiple adjacent int tensors, with no int tensor at first axis + np_data = np.random.randn(3, 4, 5, 6, 7) + np_res = np_data[::2, [1, 0], [2, 3], 0:4:2] + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.to_tensor(np_data) + y = _getitem_static( + x, (slice(None, None, 2), [1, 0], [2, 3], slice(0, 4, 2)) + ) + res = self.exe.run(fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_res) + + def test_combined_index_6(self): + # multiple adjacent and not adjacent int tensors, with no int tensor at first axis + np_data = np.random.randn(3, 4, 5, 6, 7) + np_res = 
np_data[::2, [1, 0], [2, 3], 0:4:2, [4, 6]] + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.to_tensor(np_data) + y = _getitem_static( + x, + (slice(None, None, 2), [1, 0], [2, 3], slice(0, 4, 2), [4, 6]), + ) + res = self.exe.run(fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_res) + + def test_combined_index_7(self): + # multiple adjacent and not adjacent int tensors (rank > 1d), with no int tensor at first axis + np_data = np.random.randn(3, 4, 5, 6, 7) + np_res = np_data[::2, [[1, 0]], [[2, 3]], 0:4:2, [[4, 6]]] + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.to_tensor(np_data) + y = _getitem_static( + x, + ( + slice(None, None, 2), + [[1, 0]], + [[2, 3]], + slice(0, 4, 2), + [[4, 6]], + ), + ) + res = self.exe.run(fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_res) + + def test_combined_index_8(self): + # multiple adjacent and not adjacent int tensors (rank > 1d), with int tensor at first axis + np_data = np.random.randn(3, 4, 5, 6, 7) + np_res = np_data[ + [[1, 0], [0, 1]], [[2, 3], [1, 0]], 0:4:2, [[3, 5], [4, 2]] + ] + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.to_tensor(np_data) + y = _getitem_static( + x, + ( + [[1, 0], [0, 1]], + [[2, 3], [1, 0]], + slice(0, 4, 2), + [[3, 5], [4, 2]], + ), + ) + res = self.exe.run(fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_res) + + def test_combined_index_9(self): + # multiple int tensors, with broadcast. + np_data = np.random.randn(3, 4, 5, 6, 7) + np_res = np_data[[[1, 0]], [1, 0], 0:4:2, [[3, 5], [4, 2]]] + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.to_tensor(np_data) + y = _getitem_static( + x, ([[1, 0]], [1, 0], slice(0, 4, 2), [[3, 5], [4, 2]]) + ) + res = self.exe.run(fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_res) + + def test_combined_index_10(self): + # only one bool tensor with basic-index + np_data = np.random.randn(3, 4, 5, 6) + np_res = np_data[:, [True, False, True, False], 4] + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.to_tensor(np_data) + y = _getitem_static( + x, (slice(None, None, None), [True, False, True, False], 4) + ) + res = self.exe.run(fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_res) + + def test_combined_index_11(self): + # only one bool tensor with all False + np_data = np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6)) + np_res = np_data[:, [False, False, False, False], 4] + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.to_tensor(np_data) + y = _getitem_static( + x, (slice(None, None, None), [False, False, False, False], 4) + ) + res = self.exe.run(fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_res) + + +class TestGetItemErrorCase(unittest.TestCase): + def setUp(self): + paddle.disable_static() + + def test_bool_shape_error1(self): + x = paddle.randn((4, 3, 2)) + with self.assertRaises(IndexError): + y = _getitem_static(x, ([True, False])) + + def test_bool_shape_error2(self): + x = paddle.randn((4, 3, 2)) + with self.assertRaises(IndexError): + y = _getitem_static(x, (1, paddle.to_tensor([True, False]), [0, 1])) diff --git a/test/indexing/test_setitem.py b/test/indexing/test_setitem.py new file mode 100644 index 00000000000..f412247339f --- /dev/null +++ 
b/test/indexing/test_setitem.py @@ -0,0 +1,139 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +from paddle.fluid.variable_index import _setitem_static + + +class TestSetitemInDygraph(unittest.TestCase): + def setUp(self): + paddle.disable_static() + + def test_combined_index_1(self): + np_data = np.zeros((3, 4, 5, 6), dtype='float32') + x = paddle.to_tensor(np_data) + + np_data[[0, 1], :, [1, 2]] = 10.0 + x[[0, 1], :, [1, 2]] = 10.0 + + np.testing.assert_allclose(x.numpy(), np_data) + + def test_combined_index_2(self): + np_data = np.ones((3, 4, 5, 6), dtype='float32') + x = paddle.to_tensor(np_data) + + np_data[:, 1, [1, 2], 0] = 10.0 + x[:, 1, [1, 2], 0] = 10.0 + + np.testing.assert_allclose(x.numpy(), np_data) + + def test_combined_index_3(self): + np_data = np.ones((3, 4, 5, 6), dtype='int32') + x = paddle.to_tensor(np_data) + + np_data[:, [True, False, True, False], [1, 4]] = 10 + x[:, [True, False, True, False], [1, 4]] = 10 + + np.testing.assert_allclose(x.numpy(), np_data) + + +class TestSetitemInStatic(unittest.TestCase): + def setUp(self): + paddle.enable_static() + self.exe = paddle.static.Executor() + + def test_combined_index_1(self): + # int tensor + slice (without decreasing axes) + np_data = np.zeros((3, 4, 5, 6), dtype='float32') + np_data[[0, 1], :, [1, 2]] = 10.0 + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.zeros((3, 4, 5, 6), dtype='float32') + y = _setitem_static( + x, ([0, 1], slice(None, None, None), [1, 2]), 10.0 + ) + res = self.exe.run(fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_data) + + def test_combined_index_2(self): + # int tensor + slice (with decreasing axes) + np_data = np.ones((3, 4, 5, 6), dtype='float32') + np_data[:, 1, [1, 2], 0] = 10.0 + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.ones((3, 4, 5, 6), dtype='float32') + y = _setitem_static( + x, (slice(None, None, None), 1, [1, 2], 0), 10.0 + ) + res = self.exe.run(fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_data) + + def test_combined_index_3(self): + # int tensor + bool tensor + slice (without decreasing axes) + np_data = np.ones((3, 4, 5, 6), dtype='int32') + np_data[:, [True, False, True, False], [1, 4]] = 10 + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.ones((3, 4, 5, 6), dtype='int32') + y = _setitem_static( + x, + (slice(None, None, None), [True, False, True, False], [1, 4]), + 10, + ) + res = self.exe.run(fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_data) + + def test_combined_index_4(self): + # int tensor (with ranks > 1) + bool tensor + slice (with decreasing axes) + np_data = np.ones((3, 4, 5, 6), dtype='int32') + np_data[[0, 0], [True, False, True, False], [[0, 2], [1, 4]], 4] = 16 + with paddle.static.program_guard( + 
paddle.static.Program(), paddle.static.Program() + ): + x = paddle.ones((3, 4, 5, 6), dtype='int32') + y = _setitem_static( + x, + ([0, 0], [True, False, True, False], [[0, 2], [1, 4]], 4), + 16, + ) + res = self.exe.run(fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_data) + + def test_combined_index_5(self): + # int tensor + slice + Ellipsis + np_data = np.ones((3, 4, 5, 6), dtype='int32') + np_data[..., [1, 4, 3], ::2] = 5 + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.ones((3, 4, 5, 6), dtype='int32') + y = _setitem_static( + x, + (..., [1, 4, 3], slice(None, None, 2)), + 5, + ) + res = self.exe.run(fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_data) diff --git a/test/legacy_test/test_set_value_op.py b/test/legacy_test/test_set_value_op.py index d57ef686dfa..9f797e6ab0a 100644 --- a/test/legacy_test/test_set_value_op.py +++ b/test/legacy_test/test_set_value_op.py @@ -1343,13 +1343,6 @@ class TestError(TestSetValueBase): x[::one] = self.value def _bool_list_error(self): - with self.assertRaises(TypeError): - x = paddle.ones(shape=self.shape, dtype=self.dtype) - if paddle.in_dynamic_mode(): - x[[True, False, 0]] = 0 - else: - x = paddle.static.setitem(x, [True, False, 0], 0) - with self.assertRaises(IndexError): x = paddle.ones(shape=self.shape, dtype=self.dtype) if paddle.in_dynamic_mode(): @@ -1380,7 +1373,6 @@ class TestError(TestSetValueBase): paddle.enable_static() with paddle.static.program_guard(self.program): self._value_type_error() - self._step_error() self._bool_list_error() self._bool_tensor_error() self._broadcast_mismatch() diff --git a/test/legacy_test/test_var_base.py b/test/legacy_test/test_var_base.py index b0d53788e57..99209ca798c 100644 --- a/test/legacy_test/test_var_base.py +++ b/test/legacy_test/test_var_base.py @@ -944,11 +944,9 @@ class TestVarBase(unittest.TestCase): var_tensor[var_tensor < 0.55], np_value[np_value < 0.55] ) - with self.assertRaises(ValueError): - var_tensor[[False, False, False, False]] - with self.assertRaises(ValueError): + with self.assertRaises(IndexError): var_tensor[[True, False]] - with self.assertRaises(ValueError): + with self.assertRaises(IndexError): var_tensor[[True, False, False, False, False]] with self.assertRaises(IndexError): var_tensor[paddle.to_tensor([[True, False, False, False]])] diff --git a/test/legacy_test/test_variable.py b/test/legacy_test/test_variable.py index f7338ce07e4..5774b08a328 100644 --- a/test/legacy_test/test_variable.py +++ b/test/legacy_test/test_variable.py @@ -257,10 +257,6 @@ class TestVariable(unittest.TestCase): self.assertTrue((result[2] == expected[2]).all()) self.assertTrue((result[3] == expected[3]).all()) - with self.assertRaises(IndexError): - one = paddle.ones(shape=[1]) - res = x[one, [0, 0]] - def _test_slice_index_list(self, place): data = np.random.rand(2, 3).astype("float32") prog = paddle.static.Program() @@ -323,9 +319,6 @@ class TestVariable(unittest.TestCase): self.assertTrue((result[5] == expected[5]).all()) self.assertTrue((result[6] == expected[6]).all()) - with self.assertRaises(IndexError): - res = x[[1.2, 0]] - def _test_slice_index_list_bool(self, place): data = np.random.rand(2, 3, 4).astype("float32") np_idx = np.array([[True, False, False], [True, False, True]]) @@ -375,9 +368,6 @@ class TestVariable(unittest.TestCase): with self.assertRaises(IndexError): res = x[[True, False, False]] - with self.assertRaises(ValueError): - with paddle.static.program_guard(prog): - res = x[[False, 
False]] def _test_slice_index_scalar_bool(self, place): data = np.random.rand(1, 3, 4).astype("float32") diff --git a/test/legacy_test/test_while_loop_op.py b/test/legacy_test/test_while_loop_op.py index d06b9f3e504..9ba690f5b1d 100644 --- a/test/legacy_test/test_while_loop_op.py +++ b/test/legacy_test/test_while_loop_op.py @@ -655,9 +655,9 @@ class TestApiWhileLoopSliceInBody(unittest.TestCase): startup_program = Program() with program_guard(main_program, startup_program): x = paddle.static.data(name='x', shape=[-1, 5], dtype='int32') - z = paddle.tensor.fill_constant([1], 'int32', 0) + z = paddle.tensor.fill_constant([], 'int32', 0) x_shape = paddle.shape(x) - i = paddle.tensor.fill_constant([1], 'int32', 0) + i = paddle.tensor.fill_constant([], 'int32', 0) z, _ = paddle.static.nn.while_loop(cond, body, [z, i]) place = ( -- GitLab
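
As an illustration of the decomposition that the new _getitem_static / _setitem_static helpers perform for a combined index such as x[:, [1, 0], 0:4:2, [2, 3], 4], here is a minimal, self-contained NumPy sketch (NumPy stands in for the slice / gather_nd / index_put / set_value OPs, and the function names are illustrative, not part of this patch):

    import numpy as np

    # getitem: 1) apply the basic indices (ints / slices) first, 2) transpose the
    # advanced-index axes to the front, 3) gather along the leading axes.
    def combined_getitem_sketch(x):
        sub = x[:, :, 0:4:2, :, 4]           # basic indexing only -> shape (3, 4, 2, 6)
        # The advanced indices sit on non-adjacent axes 1 and 3 of `sub`, so the
        # broadcast result dimension is placed first (pos_of_new_dim == 0).
        transed = sub.transpose(1, 3, 0, 2)  # advanced axes to the front
        idx0, idx1 = np.broadcast_arrays(np.array([1, 0]), np.array([2, 3]))
        return transed[idx0, idx1]           # the gather_nd analogue

    # setitem (step3.2): the same walk plus the way back -- slice, transpose the
    # advanced axes to the front, write the values (the index_put analogue),
    # transpose back, and write the block back (the set_value analogue).
    def combined_setitem_sketch(x, value):
        sub = x[:, :, 0:4:2, :, 4].copy()
        transed = sub.transpose(1, 3, 0, 2)
        transed[np.array([1, 0]), np.array([2, 3])] = value
        sub = transed.transpose(2, 0, 3, 1)  # np.argsort([1, 3, 0, 2]) == [2, 0, 3, 1]
        x[:, :, 0:4:2, :, 4] = sub
        return x

    x = np.random.randn(3, 4, 5, 6, 7)
    np.testing.assert_allclose(combined_getitem_sketch(x),
                               x[:, [1, 0], 0:4:2, [2, 3], 4])

    y = np.random.randn(3, 4, 5, 6, 7)
    y_ref = y.copy()
    y_ref[:, [1, 0], 0:4:2, [2, 3], 4] = 10.0
    np.testing.assert_allclose(combined_setitem_sketch(y, 10.0), y_ref)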