diff --git a/mindspore/ops/composite/multitype_ops/_multitype_ops_util.py b/mindspore/ops/composite/multitype_ops/_multitype_ops_util.py index 49773ff8ad81aa0431ad7c690753c918bbb3b99e..3a44b1e48309211e19fddcc7c1e3a0b9688a0f13 100644 --- a/mindspore/ops/composite/multitype_ops/_multitype_ops_util.py +++ b/mindspore/ops/composite/multitype_ops/_multitype_ops_util.py @@ -15,6 +15,7 @@ """constexpr util""" +from functools import reduce import numpy as np from ...primitive import constexpr from ....common.tensor import Tensor @@ -23,26 +24,27 @@ from ...._extends.utils import Slice @constexpr def check_equal(param1, param2, msg="{},{}"): + """Checks whether the two parameters are equal or not.""" if param1 != param2: raise ValueError(msg.format(param1, param2)) return param1 @constexpr def check_tensor_setitem_index(index, element_type=None): - """Check tuple index type of tensor assignment.""" + """Checks tuple index type of tensor assignment.""" if index is None: raise ValueError("Tensor's index cannot be None.") # eg. Tensor[Slice] = u if isinstance(index, Slice): return True - # eg. Tensor[Tuple] = u + # eg. Tensor[tuple] = u if isinstance(index, tuple): if not index: raise ValueError("Tensor's index cannot be empty.") - # eg. Tensor[Tuple(Slice...)] = u - if not isinstance(index[0], Slice): - raise ValueError("Index of type '{}' is not supported yet.".format(type(index[0]))) - return True + # eg. Tensor[tuple(Slice...)] = u + if isinstance(index[0], (Slice, int)): + return True + raise ValueError("Index of type '{}' is not supported yet.".format(type(index[0]))) # eg. Tensor[Tensor[dtype=bool]] = u if index == mstype.tensor: if element_type is None or element_type != mstype.bool_: @@ -57,7 +59,7 @@ def check_tensor_setitem_index(index, element_type=None): @constexpr def is_same_type(inst, type_): """ - Check whether an object is an instance of a target type. + Checks whether an object is an instance of a target type. Inputs: inst (mindspore.dtype): Inspected type. @@ -69,34 +71,23 @@ def is_same_type(inst, type_): return inst == type_ -@constexpr -def error_msg(msg="", format_values=""): - """ - Used to throw exception information. - - Inputs: - msg (str): information content. - """ - - raise ValueError(msg.format(*format_values)) - def slice_expand(input_slices, shape): """ - Convert slice to indices. + Converts slice to indices. Inputs: - slices (List or Tuple(List, ...)): Slice tuple or slice. - shape (Tuple): The shape of a sensor is an integer element tuple. + slices (Union[Slice, tuple[Slice]]): Slice tuple or slice. + shape (tuple): The shape of a sensor is an integer element tuple. Outputs: - (List, List, List), This is expressed as (begins, ends, strides). + tuple[list], This is expressed as (begins, ends, strides). """ begin = [] end = [] strides = [] index = 0 slices = None - # Slice or Tuple(Slice...) + # Slice or tuple(Slice...) if isinstance(input_slices, Slice): slices = (input_slices,) elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], Slice): @@ -119,14 +110,15 @@ def slice_expand(input_slices, shape): index += 1 return begin, end, strides + @constexpr def slice2indices(input_slices, shape): """ - Convert slice to indices. + Converts slice to indices. Inputs: - slices (List or Tuple(List, ...)): Slice tuple or slice. - shape (Tuple): The shape of a sensor is an integer element tuple. + slices (Union[Slice, tuple[Slice]]): Slice tuple or slice. + shape (tuple): The shape of a tensor is an integer element tuple. Outputs: Tensor, the shape is (n, 1). @@ -145,6 +137,7 @@ def slice2indices(input_slices, shape): @constexpr def check_indices(indices_size, index): + """Checks indices whether is empty.""" if indices_size < 1: raise ValueError("The tensor's index is unreasonable. index:{}".format(index)) return indices_size @@ -152,6 +145,7 @@ def check_indices(indices_size, index): @constexpr def check_indices_value_size(indices_size, value_size): + """Checks if the sizes are already matched.""" if value_size < 1: raise ValueError("The value assigned to tensor cannot be empty.") if value_size > 1: @@ -160,3 +154,30 @@ def check_indices_value_size(indices_size, value_size): "The value given to tensor does not match the index size. \ value size:{}, indics size:{}".format(value_size, indices_size)) return value_size + +@constexpr +def integer_to_indices(index, shape): + """Converts int or tuple[int] to indices.""" + size = reduce(lambda x, y: x * y, shape) + range_ = np.arange(size).reshape(shape) + value = range_[index] + value = value.reshape(-1, 1) + return Tensor(value, dtype=mstype.int32) + +@constexpr +def tuple_element_is_slice(indexs): + """Judges tuple element type.""" + if not indexs: + raise ValueError("Tensor's index cannot be empty.") + if isinstance(indexs, tuple) and isinstance(indexs[0], Slice): + return True + return False + +@constexpr +def tuple_element_is_int(indexs): + """Judges tuple element type.""" + if not indexs: + raise ValueError("Tensor's index cannot be empty.") + if isinstance(indexs, tuple) and isinstance(indexs[0], int): + return True + return False diff --git a/mindspore/ops/composite/multitype_ops/setitem_impl.py b/mindspore/ops/composite/multitype_ops/setitem_impl.py index 742ee57166fadd0e2d450465a5d7a75846e71b08..13d4a1ffce6ac34bb020c0523b92f3e9814f5622 100644 --- a/mindspore/ops/composite/multitype_ops/setitem_impl.py +++ b/mindspore/ops/composite/multitype_ops/setitem_impl.py @@ -25,15 +25,14 @@ setitem = base.MultitypeFuncGraph('setitem') @setitem.register("List", "Number", "String") def _list_setitem_with_string(data, number_index, value): """ - Assign value to list. + Assigns value to list. Inputs: data (list): Data of type lis. number_index (Number): Index of data. - value (String): Value given. Outputs: - List, type is same as the element type of data. + list, type is same as the element type of data. """ return F.list_setitem(data, number_index, value) @@ -41,7 +40,7 @@ def _list_setitem_with_string(data, number_index, value): @setitem.register("List", "Number", "Number") def _list_setitem_with_number(data, number_index, value): """ - Assign value to list. + Assigns value to list. Inputs: data (list): Data of type lis. @@ -49,7 +48,7 @@ def _list_setitem_with_number(data, number_index, value): value (Number): Value given. Outputs: - List, type is same as the element type of data. + list, type is same as the element type of data. """ return F.list_setitem(data, number_index, value) @@ -57,7 +56,7 @@ def _list_setitem_with_number(data, number_index, value): @setitem.register("List", "Number", "Tensor") def _list_setitem_with_Tensor(data, number_index, value): """ - Assign value to list. + Assigns value to list. Inputs: data (list): Data of type lis. @@ -65,7 +64,7 @@ def _list_setitem_with_Tensor(data, number_index, value): value (Tensor): Value given. Outputs: - List, type is same as the element type of data. + list, type is same as the element type of data. """ return F.list_setitem(data, number_index, value) @@ -73,15 +72,15 @@ def _list_setitem_with_Tensor(data, number_index, value): @setitem.register("List", "Number", "List") def _list_setitem_with_List(data, number_index, value): """ - Assign value to list. + Assigns value to list. Inputs: data (list): Data of type lis. number_index (Number): Index of data. - value (List): Value given. + value (list): Value given. Outputs: - List, type is same as the element type of data. + list, type is same as the element type of data. """ return F.list_setitem(data, number_index, value) @@ -89,15 +88,15 @@ def _list_setitem_with_List(data, number_index, value): @setitem.register("Dictionary", "String", "Tensor") def _dict_setitem_with_tensor(data, key, value): """ - Assign value to dictionary. + Assigns value to dictionary. Inputs: - data (Dictionary): Data of type dict. + data (dict): Data of type dict. key (str): Key of the data. value (Tensor): Value given. Outputs: - Dict, type is as same as the element type of data. + dict, type is as same as the element type of data. """ return F.dict_setitem(data, key, value) @@ -105,15 +104,15 @@ def _dict_setitem_with_tensor(data, key, value): @setitem.register("Dictionary", "String", "Number") def _dict_setitem_with_number(data, key, value): """ - Assign value to dictionary. + Assigns value to dictionary. Inputs: - data (Dictionary): Data of type dict. + data (dict): Data of type dict. key (str): Key of the data. value (Number): Value given. Outputs: - Dict, type is as same as the element type of data. + dict, type is as same as the element type of data. """ return F.dict_setitem(data, key, value) @@ -219,14 +218,14 @@ def _tensor_setitem_with_slice_v4(data, input_slice, value): Tensor assignment. Note: - Syntax support: A[Slice] = U + Syntax support: A[tuple(Slice)] = U, and A[tuple(Number)] = U Restraint condition: A is a Tensor Slice like "1:3, ::, :4:-1" U is a Tensor(size=1) or Tensor(size>1) Inputs: data (Tensor): Assigned tensor. - input_slice (Tuple(Slice)): Slice expression. + input_slice (Union[tuple[Slice], tuple[Number]]): Slice expression. value (Number): Assignment value. Outputs: @@ -236,39 +235,43 @@ def _tensor_setitem_with_slice_v4(data, input_slice, value): def _tensor_assgin_tensor(data, input_slice, value): - """Given a tensor value assign to tensor by slice""" - # 1. condition + """Assigns a tensor value to the tensor by slice.""" result = None check_result = mult_util.check_tensor_setitem_index(input_slice) if check_result: data_shape = F.shape(data) - data_size = F.size(data) - data_dtype = F.dtype(data) indices = mult_util.slice2indices(input_slice, data_shape) - indices_size = F.size(indices) - indices_size = mult_util.check_indices(indices_size, input_slice) - update = F.fill(data_dtype, (indices_size,), 1) - condition_1d = F.scatter_nd(indices, update, (data_size,)) - condition_1d = F.cast(condition_1d, mstype.bool_) - condition = F.reshape(condition_1d, data_shape) - # 2. u - value_fill = None - value_size = F.size(value) - - value_size = mult_util.check_indices_value_size(indices_size, value_size) - if value_size == 1: - value_fill = F.fill(data_dtype, (indices_size,), 1) - value = F.cast(value, data_dtype) - value_fill = F.tensor_mul(value_fill, value) - elif value_size > 1: - value_fill = F.reshape(value, (indices_size,)) - value_1d = F.scatter_nd(indices, value_fill, (data_size,)) - u = F.reshape(value_1d, data_shape) - # A[slice]= u -> A[B]=U -> select(B, U, A) - result = F.select(condition, u, data) + is_tuple_int = mult_util.tuple_element_is_int(input_slice) + if is_tuple_int: + indices = mult_util.integer_to_indices(input_slice, data_shape) + result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value) return result +def _tensor_indices_tensor(data, data_shape, index, indices, value): + """Assigns a tensor value to the tensor.""" + data_size = F.size(data) + data_dtype = F.dtype(data) + indices_size = F.size(indices) + indices_size = mult_util.check_indices(indices_size, index) + update = F.fill(data_dtype, (indices_size,), 1) + condition_1d = F.scatter_nd(indices, update, (data_size,)) + condition_1d = F.cast(condition_1d, mstype.bool_) + condition = F.reshape(condition_1d, data_shape) + value_fill = None + value_size = F.size(value) + + value_size = mult_util.check_indices_value_size(indices_size, value_size) + if value_size == 1: + value_fill = F.fill(data_dtype, (indices_size,), 1) + value = F.cast(value, data_dtype) + value_fill = F.tensor_mul(value_fill, value) + elif value_size > 1: + value_fill = F.reshape(value, (indices_size,)) + value_1d = F.scatter_nd(indices, value_fill, (data_size,)) + u = F.reshape(value_1d, data_shape) + return F.select(condition, u, data) + @setitem.register("Tensor", "Slice", "Number") def _tensor_setitem_with_slice_v1(data, input_slice, value): """ @@ -297,14 +300,14 @@ def _tensor_setitem_with_slice_v2(data, input_slice, value): Tensor assignment. Note: - Syntax support: A[Slice] = u + Syntax support: A[tuple(Slice)] = u, and A[tuple(Number)] = u Restraint condition: A is a Tensor. Slice like "1:3, ::, :4:-1" u is a scalar Inputs: data (Tensor): Assigned tensor. - input_slice (Tuple(Slice)): slice expression. + input_slice (Union[tuple[Slice], tuple[Number]]): slice expression. value (Number): Assignment value. Outputs: @@ -314,25 +317,46 @@ def _tensor_setitem_with_slice_v2(data, input_slice, value): def _tensor_assgin_number(data, input_slice, value): - """Given a scalar assign to tensor by slice""" - # 1. condition + """Givens a scalar assign to tensor by slice""" check_result = mult_util.check_tensor_setitem_index(input_slice) result = None if check_result: data_shape = F.shape(data) - data_size = F.size(data) - data_dtype = F.dtype(data) indices = mult_util.slice2indices(input_slice, data_shape) - indices_size = F.size(indices) - indices_size = mult_util.check_indices(indices_size, input_slice) - update = F.fill(data_dtype, (indices_size,), 1) - condition_1d = F.scatter_nd(indices, update, (data_size,)) - condition_1d = F.cast(condition_1d, mstype.bool_) - condition = F.reshape(condition_1d, data_shape) - # 2. u - value_fill = F.fill(data_dtype, (indices_size,), value) - value_1d = F.scatter_nd(indices, value_fill, (data_size,)) - u = F.reshape(value_1d, data_shape) - # A[slice]= u -> A[B]=U -> select(B, U, A) - result = F.select(condition, u, data) + is_tuple_int = mult_util.tuple_element_is_int(input_slice) + if is_tuple_int: + indices = mult_util.integer_to_indices(input_slice, data_shape) + result = _tensor_indices_number(data, data_shape, input_slice, indices, value) return result + + +def _tensor_indices_number(data, data_shape, index, indices, value): + """Assigns a scalar value to the tensor.""" + data_size = F.size(data) + data_dtype = F.dtype(data) + indices_size = F.size(indices) + indices_size = mult_util.check_indices(indices_size, index) + update = F.fill(data_dtype, (indices_size,), 1) + condition_1d = F.scatter_nd(indices, update, (data_size,)) + condition_1d = F.cast(condition_1d, mstype.bool_) + condition = F.reshape(condition_1d, data_shape) + value_fill = F.fill(data_dtype, (indices_size,), value) + value_1d = F.scatter_nd(indices, value_fill, (data_size,)) + u = F.reshape(value_1d, data_shape) + return F.select(condition, u, data) + + +@setitem.register("Tensor", "Number", "Number") +def _tensor_setitem_with_int_v1(data, index, value): + """Syntax: A[1] = 3""" + data_shape = F.shape(data) + indices = mult_util.integer_to_indices(index, data_shape) + return _tensor_indices_number(data, data_shape, index, indices, value) + + +@setitem.register("Tensor", "Number", "Tensor") +def _tensor_setitem_with_int_v2(data, index, value): + """Syntax: A[1] = Tensor""" + data_shape = F.shape(data) + indices = mult_util.integer_to_indices(index, data_shape) + return _tensor_indices_tensor(data, data_shape, index, indices, value) diff --git a/tests/ut/python/ops/test_tensor_slice.py b/tests/ut/python/ops/test_tensor_slice.py index b547fdc06b9524d8186ea40a3b2de8e0d2a110b7..58af9bc273f3702e27137fb0487b929ac26e24f9 100644 --- a/tests/ut/python/ops/test_tensor_slice.py +++ b/tests/ut/python/ops/test_tensor_slice.py @@ -139,7 +139,7 @@ class TensorAssignWithSlice(Cell): z = a return z -def test_tensor_assign_with_slice(): +def test_tensor_assign(): context.set_context(mode=context.GRAPH_MODE, save_graphs=True) net = TensorAssignWithSlice() net2= TensorAssignWithSlice2() @@ -148,6 +148,7 @@ def test_tensor_assign_with_slice(): a = np.arange(60).reshape(3,4,5) b = Tensor([1]) Ta = Tensor(a) + Ta4d = Tensor(a.reshape(1,3,4,5)) Tb= Tensor([1,3]) Tc= Tensor([]) t = Tensor([1, 2, 3, 4, 5, 6, 7, 8]) @@ -185,6 +186,47 @@ def test_tensor_assign_with_slice(): with pytest.raises(ValueError): net_e1(Ta, 2) + net = TensorAssignWithInteger() + # Error for A[Number] = scalar/Tensor + # 1. A[Number] = U, U is a Tensor, u.size not match + with pytest.raises(ValueError): + net(Ta, Tb) + with pytest.raises(ValueError): + net(Ta, Tc) + # 2. A[Number] = U, the number index error + with pytest.raises(IndexError): + net(Ta4d, b) + + # Error for A[(n,m)] = scalar/Tensor + # 1. A[(n,m)] = U, U is a tensor. u.size not match + net = TensorAssignWithTupleInteger() + with pytest.raises(ValueError): + net(Ta, Tc) + with pytest.raises(ValueError): + net(Ta, Tb) + # 2. A[(n,m)] = U, the number index error + with pytest.raises(IndexError): + net(Ta4d, b) + +class TensorAssignWithInteger(Cell): + def __init__(self): + super(TensorAssignWithInteger, self).__init__() + + def construct(self, a, b): + a[1] = 1 + a[0] = b + return a + +class TensorAssignWithTupleInteger(Cell): + def __init__(self): + super(TensorAssignWithTupleInteger, self).__init__() + + def construct(self, a, b): + a[(1)] = 1 + a[(1)] = b + a[(1,1)] = b + a[(1,1)] = 1 + return a class TensorAssignWithBoolTensorIndex(Cell): def __init__(self): @@ -274,6 +316,14 @@ def test_tensor_assign_bool_index(): net4(Ta, u_scalar) test_cases = [ + ('TensorAssignWithTupleInteger', { + 'block': TensorAssignWithTupleInteger(), + 'desc_inputs': [Ta, u_tensor], + }), + ('TensorAssignWithInteger', { + 'block': TensorAssignWithInteger(), + 'desc_inputs': [Ta, u_tensor], + }), ('TensorAssignWithSlice', { 'block': TensorAssignWithSlice(), 'desc_inputs': [Ta, u_tensor],