# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Indexing helpers for Paddle variables.

``_getitem_impl_`` and ``_setitem_impl_`` translate Python index expressions
(ints, slices, ``...``, ``None``, int/bool lists, numpy arrays and tensors)
into Paddle ``slice`` / ``strided_slice`` / ``gather_nd`` / ``scatter`` /
``set_value`` operations.
"""

import sys
from functools import reduce

import numpy as np

from . import unique_name
from . import core
import paddle

# Largest int32 value; used as an "unbounded" start/end when a slice bound
# has to reach past the last element of an axis.
MAX_INTEGER = 2**31 - 1


def is_list_tuple(index, contain_type):
    """Return True if ``index`` is a (possibly nested) list/tuple whose
    leaves are all exactly of type ``contain_type``.

    The leaf test compares the exact type, so ``bool`` leaves are not
    accepted as ``int``. Empty lists/tuples are accepted.
    """

    def _is_list_tuple(item):
        if isinstance(item, (list, tuple)):
            return all(_is_list_tuple(sub) for sub in item)
        return type(item) is contain_type

    if not isinstance(index, (tuple, list)):
        return False
    return all(_is_list_tuple(sub) for sub in index)


def is_one_dim_list(index, contain_type):
    """Return True if ``index`` is a flat ``list`` whose elements all pass
    ``isinstance(elem, contain_type)``. Tuples are rejected."""
    return isinstance(index, list) and all(
        isinstance(i, contain_type) for i in index
    )


def get_list_index_shape(var_dims, index_dims):
    """Compute the output shape of gathering ``var_dims`` with stacked indexes.

    Args:
        var_dims: shape of the indexed tensor.
        index_dims: ``[k] + broadcast_index_shape`` where ``k`` is the number
            of leading axes of ``var_dims`` that are indexed.

    Returns:
        ``broadcast_index_shape + var_dims[k:]`` as a list.
    """
    var_dims_size = len(var_dims)
    index_dims_size = len(index_dims)

    out_dims_size = var_dims_size - index_dims[0] + index_dims_size - 1
    out_dims_shape = [1] * out_dims_size

    out_dims_shape[: index_dims_size - 1] = index_dims[1:]
    out_dims_shape[index_dims_size - 1 :] = var_dims[index_dims[0] :]
    return out_dims_shape


class SliceInfo:
    """Collects list/array/tensor ("advanced") indexes for one indexing
    expression and performs the gather (``get_item``) or scatter
    (``set_item``) they describe."""

    def __init__(self):
        # Shape every collected index has been broadcast to.
        self.pre_shape = None
        # Collected index tensors, one per indexed leading axis.
        self.indexes = []
        # dtype shared by all collected indexes.
        self.dtype = None

    def update(self, index):
        """Add one index, converting it to a tensor when needed.

        All collected indexes must share a dtype. When the new index has a
        different shape, every collected index is broadcast to the common
        shape.

        Raises:
            IndexError: if the dtype differs from earlier indexes.
            ValueError: if ``index`` is not an int list/tuple, ndarray or
                tensor.
        """
        if not (
            is_list_tuple(index, int)
            or isinstance(index, (paddle.fluid.Variable, np.ndarray))
        ):
            raise ValueError(
                "Index should be list/tuple of int or Tensor, but received {}.".format(
                    index
                )
            )

        # convert index to Tensor
        if not isinstance(index, paddle.fluid.Variable):
            index = paddle.assign(index)

        if self.dtype is None:
            self.dtype = index.dtype
        elif index.dtype != self.dtype:
            raise IndexError(
                "Data type of Tensor/List index should be same. The current data type is {}, but the previous data type is {}.".format(
                    index.dtype, self.dtype
                )
            )

        self.indexes.append(index)

        if self.pre_shape is None:
            self.pre_shape = index.shape
        else:
            if self.pre_shape != index.shape:
                # broadcast all collected indexes to the common shape
                cur_shape = paddle.broadcast_shape(self.pre_shape, index.shape)
                for i in range(len(self.indexes)):
                    self.indexes[i] = paddle.broadcast_to(
                        self.indexes[i], cur_shape
                    )
            self.pre_shape = self.indexes[-1].shape

    def shape_stride(self, shape):
        """Return row-major (C-order) element strides for ``shape``."""
        s = [1] * len(shape)
        for i in range(len(shape) - 2, -1, -1):
            s[i] = shape[i + 1] * s[i + 1]
        return s

    def numel(self, shape):
        """Return the product of the dims in ``shape`` (1 for an empty shape)."""
        return reduce(lambda x, y: x * y, shape, 1)

    def get_offset_stride(self, tensor_shape):
        """Stack the collected indexes into a ``gather_nd``-style index.

        The k collected indexes (all broadcast to ``self.pre_shape``) are
        stacked along a new axis 0, which is then transposed to the last
        position, giving a tensor of shape ``[*pre_shape, k]``.

        Raises:
            ValueError: if a collected index is not a tensor, or if more
                indexes were collected than ``tensor_shape`` has dims (a
                single index is always allowed).
        """
        for index in self.indexes:
            if not isinstance(index, paddle.fluid.Variable):
                raise ValueError(
                    "only support list/tensor index, but received {}.".format(
                        type(index)
                    )
                )

        if len(self.indexes) > len(tensor_shape) and len(self.indexes) != 1:
            # Report the number of indices actually supplied.
            raise ValueError(
                "too many indices for tensor: tensor is {}-dimensional, but {} were indexed".format(
                    len(tensor_shape), len(self.indexes)
                )
            )

        stacked = paddle.stack(self.indexes)
        # Move the stacking axis (0) to the end.
        axes = list(range(1, len(self.pre_shape) + 1)) + [0]
        return paddle.transpose(stacked, axes)

    def get_item(self, tensor):
        """Gather the elements of ``tensor`` addressed by the collected indexes."""
        shape_transpose = self.get_offset_stride(tensor.shape)
        index = paddle.assign(shape_transpose)
        return paddle.gather_nd(tensor, index)

    def set_item(self, tensor_origin, value):
        """Write ``value`` into ``tensor_origin`` at the collected indexes.

        Tensors whose dtype is not FP32/FP64 are processed as FP32 and cast
        back to their dtype at the end. The indexed positions are turned into
        linear offsets over the indexed leading dims (row-major strides) and
        written with ``paddle.scatter``. ``tensor_origin`` is updated in place
        and returned.

        Raises:
            ValueError: if ``value`` cannot broadcast into the gathered shape.
        """
        if not isinstance(value, paddle.fluid.Variable):
            value = paddle.assign(value)
        tensor_type = None

        if tensor_origin.dtype in [
            core.VarDesc.VarType.FP32,
            core.VarDesc.VarType.FP64,
        ]:
            tensor = tensor_origin
        else:
            tensor_type = tensor_origin.dtype
            tensor = tensor_origin.astype(core.VarDesc.VarType.FP32)

        if value.dtype != tensor.dtype:
            value = value.astype(tensor.dtype)

        shape_transpose = self.get_offset_stride(tensor_origin.shape)
        index = paddle.assign(shape_transpose)

        gather_tensor_shape = get_list_index_shape(
            tensor.shape,
            [len(self.indexes)] + list(self.indexes[-1].shape),
        )

        # Right-align value's dims against the gathered shape (numpy-style
        # broadcasting) and verify each dim is equal or 1.
        value_dims_bd = [1] * len(gather_tensor_shape)
        if len(value.shape) > 0:
            # ``lst[-0:]`` addresses the whole list, so a 0-D value must skip
            # this step and broadcast from the all-ones dims instead.
            value_dims_bd[-len(value.shape) :] = list(value.shape)

        for i in range(len(gather_tensor_shape)):
            if not (
                value_dims_bd[i] == gather_tensor_shape[i]
                or value_dims_bd[i] == 1
            ):
                raise ValueError(
                    "{} can not broadcast into {}".format(
                        value.shape, gather_tensor_shape
                    )
                )

        value_broadcast = paddle.broadcast_to(value, gather_tensor_shape)

        value_1d = value_broadcast.reshape(
            [-1] + gather_tensor_shape[len(index.shape) - 1 :]
        )

        index_1d = index.reshape([-1, index.shape[-1]])

        # Convert each multi-dim index row into a linear offset over the
        # indexed leading dims.
        tensor_stride = paddle.assign(
            self.shape_stride(tensor.shape[: index.shape[-1]])
        )
        inds = [
            (index_1d[i] * tensor_stride).sum()
            for i in range(index_1d.shape[0])
        ]
        index_1d = paddle.stack(inds).reshape([-1])
        t_reshape = tensor.reshape([-1] + list(tensor.shape[index.shape[-1] :]))
        out = paddle.scatter(t_reshape, index_1d, value_1d)
        if tensor_type is not None:
            out = out.astype(tensor_type)
        tensor_origin[:] = out.reshape(tensor_origin.shape)

        return tensor_origin


def replace_ellipsis(var, item):
    """Expand a single ``Ellipsis`` in ``item`` into ``slice(None)`` entries.

    For var, var.shape = [3,4,5,6]:

      var[..., 1:2] -> var[:, :, :, 1:2]
      var[0, ...] -> var[0]
      var[0, ..., 1:2] -> var[0, :, :, 1:2]

    ``None`` entries do not consume a dim, so they are excluded when counting
    how many slices the ellipsis stands for.

    Returns:
        A new list of index items.

    Raises:
        IndexError: if ``item`` contains more than one ellipsis.
    """
    item = list(item)

    # Compare by identity: tensor items overload ``__eq__``, so ``==``-based
    # ``list.count`` / ``list.index`` must not be used on this list.
    ell_positions = [i for i, ele in enumerate(item) if ele is Ellipsis]
    if not ell_positions:
        return item
    if len(ell_positions) > 1:
        raise IndexError("An index can only have a single ellipsis ('...')")

    ell_idx = ell_positions[0]
    if ell_idx == len(item) - 1:
        return item[:-1]

    none_count = sum(1 for ele in item if ele is None)
    item[ell_idx : ell_idx + 1] = [slice(None)] * (
        len(var.shape) - len(item) + none_count + 1
    )
    return item


def replace_ndarray(item):
    """Return a list copy of ``item`` with numpy arrays converted to tensors
    via ``paddle.assign``."""
    return [
        paddle.assign(slice_item)
        if isinstance(slice_item, np.ndarray)
        else slice_item
        for slice_item in item
    ]


def replace_none(item):
    """Split ``None`` entries out of ``item``.

    Returns:
        ``(new_item, none_axes)``: the items without ``None`` and the original
        positions at which ``None`` appeared.
    """
    new_item = []
    none_axes = []
    for i, slice_item in enumerate(item):
        if slice_item is None:
            none_axes.append(i)
        else:
            new_item.append(slice_item)
    return new_item, none_axes


def is_integer_or_scalar_tensor(ele):
    """Return True for a Python int (including bool) or a Variable of shape [1]."""
    from .framework import Variable

    if isinstance(ele, int):
        return True
    return (
        isinstance(ele, Variable) and len(ele.shape) == 1 and ele.shape[0] == 1
    )


def is_bool_tensor(ele):
    """Return True if ``ele`` is a Variable with dtype ``paddle.bool``."""
    from .framework import Variable

    return isinstance(ele, Variable) and ele.dtype == paddle.bool


def deal_attrs(attrs, attr, attr_name, tensor_attr_name, inputs, infer_flags):
    """Store a starts/ends/strides list either as a plain attribute or, when
    it contains Variables, as a tensor-list input.

    In the Variable case every Variable entry is recorded as ``-1`` in
    ``attrs[attr_name]`` (which must already be a list) and its
    ``infer_flags`` slot is set to ``-1``; int entries are appended as-is.
    """
    from .framework import Variable

    if paddle.utils._contain_var(attr):
        inputs[tensor_attr_name] = paddle.utils._convert_to_tensor_list(
            attr, dtype="int64"
        )
        for i, dim in enumerate(attr):
            if isinstance(dim, Variable):
                attrs[attr_name].append(-1)
                infer_flags[i] = -1
            else:
                attrs[attr_name].append(dim)
    else:
        attrs[attr_name] = attr


# the item is a tensor of bool
def get_value_for_bool_tensor(var, item):
    """Return the elements of ``var`` selected by the bool tensor ``item``.

    Selection uses ``gather_nd`` on ``nonzero(item)``. When no element of
    ``item`` is True, an empty tensor of shape ``[0] + var.shape[rank:]`` is
    returned instead, where ``rank`` is the number of dims of ``item``.

    Raises:
        IndexError: if ``item`` has more dims than ``var`` or a known dim of
            ``item`` differs from the matching dim of ``var``.
    """
    if len(item.shape) > len(var.shape):
        raise IndexError(
            "The dims of bool index doesn't match indexed array, "
            "the dims of bool index except to be equal or less "
            "than {}, but received {}.".format(len(var.shape), len(item.shape))
        )
    for i, dim_len in enumerate(item.shape):
        if dim_len != -1 and var.shape[i] != -1 and dim_len != var.shape[i]:
            raise IndexError(
                "The dimension of bool index doesn't match indexed array along "
                "dimension {}, the target dimension is {}, but received {}.".format(
                    i, var.shape[i], dim_len
                )
            )

    empty_shape = [0] + list(var.shape[len(item.shape) :])

    def idx_not_empty(var, item):
        from ..tensor import gather_nd

        bool_2_idx = paddle.nonzero(item == True)
        return gather_nd(var, bool_2_idx)

    from paddle.static.nn import cond

    return cond(
        item.any(),
        lambda: idx_not_empty(var, item),
        lambda: paddle.empty(empty_shape, var.dtype),
    )


def _getitem_impl_(var, item):
    """
    Slice the variable.

    Args:
        item(int/slice/tuple) : the index.

    Returns:
        Sliced variable
    """
    from .framework import default_main_program, Variable

    if isinstance(item, list):
        if not is_one_dim_list(item, int):
            item = tuple(item)

    if not isinstance(item, tuple):
        item = (item,)

    decrease_axes = []
    axes = []
    starts = []
    ends = []
    steps = []
    reverse_axes = []

    use_strided_slice = False
    item = replace_ndarray(item)
    item = replace_ellipsis(var, item)
    item, none_axes = replace_none(item)
    slice_info = SliceInfo()
    is_tensor_array = (
        hasattr(var, "desc")
        and var.desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY
    )

    for dim, slice_item in enumerate(item):
        if is_integer_or_scalar_tensor(slice_item) and not is_bool_tensor(
            slice_item
        ):
            if (
                not is_tensor_array
                and isinstance(slice_item, int)
                and var.shape[dim] is not None
                and var.shape[dim] >= 0
                and slice_item >= var.shape[dim]
            ):
                # For python, if users write a, b = var, the __getitem__
                # method will iterate through 0, 1, 2 ... until __getitem__
                # throws an IndexError, then stop. The var[0], var[1] will
                # be given to a, b respectively. If more values are given,
                # the unpack size would cause error.
                # We raise IndexError here to support grammar like `a, b = var`
                raise IndexError(
                    "slice_item %d at dim %d should be >= 0 and < var.shape[%d]: %d"
                    % (slice_item, dim, dim, var.shape[dim])
                )
            decrease_axes.append(dim)
            start = slice_item
            step = 1
            # -1 selects the last element, so its end must run to the axis end.
            end = slice_item + 1 if slice_item != -1 else MAX_INTEGER

        elif isinstance(slice_item, slice):
            start = slice_item.start
            end = slice_item.stop
            step = slice_item.step

            # A bare ``:`` keeps the whole axis; no slice op is needed.
            if start is None and end is None and step is None:
                continue

            step = 1 if step is None else step

            if start is None:
                start = 0 if step > 0 else MAX_INTEGER
            if end is None:
                if (
                    paddle.fluid.framework._non_static_mode()
                    or not is_tensor_array
                ) and var.shape[dim] != -1:
                    end = var.shape[dim] if step > 0 else -1
                else:
                    end = MAX_INTEGER if step > 0 else -1

        elif isinstance(slice_item, list):
            # A pure (nested) int list is an advanced index.
            if is_list_tuple(slice_item, int):
                slice_info.update(slice_item)
                continue

            all_bool = True
            for i in slice_item:
                if type(i) is int:
                    all_bool = False
                elif not isinstance(i, bool):
                    raise TypeError("Only support int or bool in index list.")

            if len(item) != 1:
                raise IndexError(
                    "When index contains a list, its length must be 1, but received {}.".format(
                        len(item)
                    )
                )
            if all_bool:
                if len(slice_item) != var.shape[0]:
                    raise IndexError(
                        "The dimension of bool index doesn't match indexed array along "
                        "dimension 0, the target dimension is {}, but received {}.".format(
                            var.shape[0], len(slice_item)
                        )
                    )
                # Bool mask -> positions of the True entries.
                slice_item = [
                    idx for idx, ele in enumerate(slice_item) if ele is True
                ]
            else:
                # Mixed int/bool list: bools are used as the ints 1 / 0.
                new_slice_item = []
                for ele in slice_item:
                    if type(ele) is int:
                        new_slice_item.append(ele)
                    elif ele is True:
                        new_slice_item.append(1)
                    else:
                        new_slice_item.append(0)
                slice_item = new_slice_item

            from ..tensor import index_select

            idx = paddle.assign(np.array(slice_item).astype("int32"))
            return index_select(var, index=idx, axis=0)

        elif isinstance(slice_item, (Variable, core.eager.Tensor)):
            if len(item) == 1:
                from ..tensor import index_select

                if slice_item.dtype == paddle.bool:
                    return get_value_for_bool_tensor(var, slice_item)
                if len(slice_item.shape) == 1:
                    return index_select(var, index=slice_item, axis=0)
            slice_info.update(slice_item)
            continue

        else:
            raise IndexError(
                "Valid index accept int or slice or ellipsis or list, but received {}.".format(
                    slice_item
                )
            )

        axes.append(dim)
        starts.append(start)
        ends.append(end)
        steps.append(step)
        use_strided_slice = True if step != 1 else use_strided_slice

    if slice_info.indexes:
        # Advanced indexes cannot be mixed with ints/slices here.
        if len(slice_info.indexes) != len(item):
            raise IndexError(
                "Valid index accept int or slice or ellipsis or list, but received {}.".format(
                    item
                )
            )
        return slice_info.get_item(var)

    inputs = {'Input': [var]}
    attrs = {
        'axes': axes,
        'starts': [],
        'ends': [],
        'decrease_axis': decrease_axes,
    }
    if use_strided_slice:
        attrs['strides'] = []

    infer_flags = [1] * len(axes)
    deal_attrs(attrs, starts, "starts", "StartsTensorList", inputs, infer_flags)
    deal_attrs(attrs, ends, "ends", "EndsTensorList", inputs, infer_flags)
    deal_attrs(
        attrs, steps, "strides", "StridesTensorList", inputs, infer_flags
    )
    attrs['infer_flags'] = infer_flags

    out = var
    if len(axes) > 0:
        op_type = "strided_slice" if use_strided_slice else "slice"
        if paddle.fluid.framework.in_dygraph_mode() and op_type == "slice":
            if "StartsTensorList" in inputs:
                st = inputs['StartsTensorList']
            else:
                st = attrs['starts']
            if "EndsTensorList" in inputs:
                end = inputs['EndsTensorList']
            else:
                end = attrs['ends']
            out = paddle._C_ops.slice(
                var, axes, st, end, attrs['infer_flags'], attrs['decrease_axis']
            )
        else:
            target_block = default_main_program().current_block()

            slice_out_var = target_block.create_var(
                name=unique_name.generate_with_ignorable_key(
                    var.name + "_" + op_type
                ),
                dtype=var.dtype,
            )
            target_block.append_op(
                type=op_type,
                inputs=inputs,
                outputs={'Out': [slice_out_var]},
                attrs=attrs,
            )
            out = slice_out_var

    if len(reverse_axes) > 0:
        from .layers.tensor import reverse

        out = reverse(out, axis=reverse_axes)

    # Deal with cases when all axes are decreased.
    # After slice, the shape of out is [1], which should have been [], but Paddle doesn't support scalar.
    # In order to ensure the correctness of the final shape of out, one dimension of out needs to be decreased.
    # For example:
    #
    #   x.shape: (2,3,4)
    #   out = x[0, 1, 1, None]
    #   out.shape : (1)
    if len(decrease_axes) == len(var.shape):
        none_axes = none_axes[1:]

    if len(none_axes) > 0:
        # Deal with cases that decrease_axes is not empty: each None position
        # shifts left by the number of decreased axes that precede it.
        # For example:
        #
        #   x.shape: (2,3,4)
        #   out = x[0, 0:2, None]
        #   out.shape : (2, 1, 4)
        none_axes = [
            axis - sum(1 for d in decrease_axes if d < axis)
            for axis in none_axes
        ]

        from ..tensor import unsqueeze

        out = unsqueeze(out, axis=none_axes)

    return out


def _setitem_for_tensor_array(var, item, value):
    """branches for tensor array setitem operation.
    A item can be a:
    (1) int/Variable, which is a simple number/variable such as [1], [-2]
    (2) Slice, which is represented by bounds such as [2:-1]
    (3) Tuple, which includes the above two cases such as [2:-1, 1]
    If item is case (1), we perform paddle.tensor.array_write,
    in other cases, we raise a NotImplementedError.
    """
    from ..framework import LayerHelper, core, _non_static_mode
    from .framework import Variable

    assert (
        not _non_static_mode()
    ), "setitem for tensor_array must be called in static graph mode."
    if not isinstance(item, (Variable, int)):
        raise NotImplementedError(
            "Only support __setitem__ by Int/Variable in tensor_array, but gets {}".format(
                type(item)
            )
        )

    from paddle.jit.dy2static.variable_trans_func import (
        to_static_variable,
    )
    from paddle.tensor import array_write

    item = paddle.cast(to_static_variable(item), dtype='int64')
    value = to_static_variable(value)
    array_write(x=value, i=item, array=var)


def _setitem_impl_(var, item, value):
    """Assign ``value`` to ``var[item]`` in place and return ``var``.

    Tensor arrays are delegated to ``_setitem_for_tensor_array``; bool
    list/tensor indexes to ``set_value_for_bool_tensor``; int-list/tensor
    indexes to ``SliceInfo.set_item``. Everything else is lowered to a
    ``set_value`` op.

    Raises:
        ValueError: for a zero step, or a Tensor step combined with a
            missing start/end.
        TypeError: for unsupported list elements or value types.
        IndexError: for unsupported index items or invalid bool indexes.
    """
    from .framework import default_main_program, Variable
    from paddle.fluid import core

    if var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
        return _setitem_for_tensor_array(var, item, value)

    inputs = {'Input': var}
    if isinstance(item, list):
        if not is_one_dim_list(item, int):
            item = tuple(item)
    # 1. Parse item
    if not isinstance(item, tuple):
        item = (item,)

    decrease_axes = []
    axes = []
    starts = []
    ends = []
    steps = []

    item = replace_ndarray(item)
    item = replace_ellipsis(var, item)
    item, none_axes = replace_none(item)
    slice_info = SliceInfo()
    dim = 0
    for slice_item in item:
        if is_integer_or_scalar_tensor(slice_item) and not is_bool_tensor(
            slice_item
        ):
            decrease_axes.append(dim)
            start = slice_item
            end = slice_item + 1 if slice_item != -1 else MAX_INTEGER
            step = 1

        elif isinstance(slice_item, slice):
            start = slice_item.start
            end = slice_item.stop
            step = slice_item.step

            if start is None and end is None and step is None:
                dim += 1
                continue

            step = 1 if step is None else step

            if not isinstance(step, Variable) and step == 0:
                raise ValueError(
                    "When assign a value to a paddle.Tensor, step can not be 0, "
                    "but received step is {}.".format(step)
                )

            if isinstance(step, Variable) and (start is None or end is None):
                raise ValueError(
                    "When assign a value to a paddle.Tensor, it's not supported that "
                    "the start or end is None when the type of step is paddle.Tensor."
                )

            if start is None:
                start = 0 if step > 0 else MAX_INTEGER

            if end is None:
                end = MAX_INTEGER if step > 0 else (0 - MAX_INTEGER)

        elif isinstance(slice_item, list):
            if is_list_tuple(slice_item, int):
                slice_info.update(slice_item)
                continue

            for i in slice_item:
                if not isinstance(i, bool):
                    raise TypeError(
                        "Doesn't support {} in index list.".format(type(i))
                    )

            if len(item) != 1:
                raise IndexError(
                    "When index contains a bool list, its length must be 1, but received {}.".format(
                        len(item)
                    )
                )

            idx_tensor = paddle.assign(slice_item)
            return set_value_for_bool_tensor(var, idx_tensor, value)

        elif isinstance(slice_item, Variable):
            if slice_item.dtype != core.VarDesc.VarType.BOOL:
                slice_info.update(slice_item)
                continue
            if len(item) != 1:
                raise IndexError(
                    "When index contains a bool tensor, its length must be 1, but received {}.".format(
                        len(item)
                    )
                )
            return set_value_for_bool_tensor(var, slice_item, value)

        else:
            raise IndexError(
                "Valid index accept int, slice, ellipsis, None, list of bool, Variable, "
                "but received {}.".format(slice_item)
            )

        axes.append(dim)
        starts.append(start)
        ends.append(end)
        steps.append(step)

        dim += 1

    if slice_info.indexes:
        if len(slice_info.indexes) != len(item):
            raise IndexError(
                "Valid index accept int or slice or ellipsis or list, but received {}.".format(
                    item
                )
            )
        return slice_info.set_item(var, value)

    attrs = {
        'axes': axes,
        'starts': starts,
        'ends': ends,
        'steps': steps,
        'decrease_axes': decrease_axes,
        'none_axes': none_axes,
    }

    # Bounds containing Variables are passed as tensor-list inputs instead
    # of attributes.
    if paddle.utils._contain_var(starts):
        inputs['StartsTensorList'] = paddle.utils._convert_to_tensor_list(
            starts
        )
        del attrs['starts']
    if paddle.utils._contain_var(ends):
        inputs['EndsTensorList'] = paddle.utils._convert_to_tensor_list(ends)
        del attrs['ends']
    if paddle.utils._contain_var(steps):
        inputs['StepsTensorList'] = paddle.utils._convert_to_tensor_list(steps)
        del attrs['steps']

    # 2. Parse value
    dtype = var.dtype
    attrs['dtype'] = dtype

    from .data_feeder import convert_dtype

    #  2.1 value is an integer, float or complex
    if isinstance(value, (bool, int, float, complex)):
        value = np.array([value]).astype(convert_dtype(dtype))

    #  2.2 value is a np.ndarray
    if isinstance(value, np.ndarray):
        attrs["values"] = value.ravel().tolist()
        attrs["shape"] = list(value.shape)

    elif isinstance(value, (Variable, core.eager.Tensor)):
        inputs["ValueTensor"] = value
    else:
        raise TypeError(
            "Only support to assign an integer, float, numpy.ndarray or "
            "paddle.Tensor to a paddle.Tensor, but received {}".format(
                type(value)
            )
        )

    if paddle.fluid.framework._non_static_mode():
        var._bump_inplace_version()

    cur_block = default_main_program().current_block()
    cur_block.append_op(
        type="set_value",
        inputs=inputs,
        outputs={'Out': var},
        attrs=attrs,
        inplace_map={"Input": "Out"},
    )

    return var


# the item is a tensor of bool
def set_value_for_bool_tensor(var, item, value):
    """Assign ``value`` to the positions of ``var`` where ``item`` is True.

    The write is done as ``var[:] = scatter_nd_add(var, idx, value - old)``
    with ``idx = nonzero(item)``. When no element of ``item`` is True,
    nothing is written. Returns ``var``.

    Raises:
        IndexError: if ``item`` has more dims than ``var`` or a known dim of
            ``item`` differs from the matching dim of ``var``.
    """
    if len(item.shape) > len(var.shape):
        raise IndexError(
            "The dims of bool index doesn't match indexed array, "
            "the dims of bool index except to be equal or less "
            "than {}, but received {}.".format(len(var.shape), len(item.shape))
        )
    for i, dim_len in enumerate(item.shape):
        if dim_len != -1 and var.shape[i] != -1 and dim_len != var.shape[i]:
            raise IndexError(
                "The dimension of bool index doesn't match indexed array along "
                "dimension {}, the target dimension is {}, but received {}.".format(
                    i, var.shape[i], dim_len
                )
            )

    def idx_not_empty(var, item, value):
        from .framework import Variable
        from ..tensor import gather_nd, scatter_nd_add

        if not isinstance(value, Variable):
            value = paddle.assign(value).cast(var.dtype)

        idx = paddle.nonzero(item)
        gather_val = gather_nd(var, idx)
        # Adding (value - old) at idx leaves ``value`` at those positions.
        gather_val_new = value - gather_val
        out = scatter_nd_add(var, idx, gather_val_new)
        var[:] = out

    from paddle.static.nn import cond

    # If all the bool index is False, just do nothing
    cond(item.any(), lambda: idx_not_empty(var, item, value))

    return var