提交 31aae361 编写于 作者: C candanzg

Tensor assign with integer

Signed-off-by: Ncandanzg <zhangshucheng@huawei.com>
上级 496ffff3
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
"""constexpr util""" """constexpr util"""
from functools import reduce
import numpy as np import numpy as np
from ...primitive import constexpr from ...primitive import constexpr
from ....common.tensor import Tensor from ....common.tensor import Tensor
...@@ -23,26 +24,27 @@ from ...._extends.utils import Slice ...@@ -23,26 +24,27 @@ from ...._extends.utils import Slice
@constexpr @constexpr
def check_equal(param1, param2, msg="{},{}"): def check_equal(param1, param2, msg="{},{}"):
"""Checks whether the two parameters are equal or not."""
if param1 != param2: if param1 != param2:
raise ValueError(msg.format(param1, param2)) raise ValueError(msg.format(param1, param2))
return param1 return param1
@constexpr @constexpr
def check_tensor_setitem_index(index, element_type=None): def check_tensor_setitem_index(index, element_type=None):
"""Check tuple index type of tensor assignment.""" """Checks tuple index type of tensor assignment."""
if index is None: if index is None:
raise ValueError("Tensor's index cannot be None.") raise ValueError("Tensor's index cannot be None.")
# eg. Tensor[Slice] = u # eg. Tensor[Slice] = u
if isinstance(index, Slice): if isinstance(index, Slice):
return True return True
# eg. Tensor[Tuple] = u # eg. Tensor[tuple] = u
if isinstance(index, tuple): if isinstance(index, tuple):
if not index: if not index:
raise ValueError("Tensor's index cannot be empty.") raise ValueError("Tensor's index cannot be empty.")
# eg. Tensor[Tuple(Slice...)] = u # eg. Tensor[tuple(Slice...)] = u
if not isinstance(index[0], Slice): if isinstance(index[0], (Slice, int)):
raise ValueError("Index of type '{}' is not supported yet.".format(type(index[0])))
return True return True
raise ValueError("Index of type '{}' is not supported yet.".format(type(index[0])))
# eg. Tensor[Tensor[dtype=bool]] = u # eg. Tensor[Tensor[dtype=bool]] = u
if index == mstype.tensor: if index == mstype.tensor:
if element_type is None or element_type != mstype.bool_: if element_type is None or element_type != mstype.bool_:
...@@ -57,7 +59,7 @@ def check_tensor_setitem_index(index, element_type=None): ...@@ -57,7 +59,7 @@ def check_tensor_setitem_index(index, element_type=None):
@constexpr @constexpr
def is_same_type(inst, type_): def is_same_type(inst, type_):
""" """
Check whether an object is an instance of a target type. Checks whether an object is an instance of a target type.
Inputs: Inputs:
inst (mindspore.dtype): Inspected type. inst (mindspore.dtype): Inspected type.
...@@ -69,34 +71,23 @@ def is_same_type(inst, type_): ...@@ -69,34 +71,23 @@ def is_same_type(inst, type_):
return inst == type_ return inst == type_
@constexpr
def error_msg(msg="", format_values=""):
"""
Used to throw exception information.
Inputs:
msg (str): information content.
"""
raise ValueError(msg.format(*format_values))
def slice_expand(input_slices, shape): def slice_expand(input_slices, shape):
""" """
Convert slice to indices. Converts slice to indices.
Inputs: Inputs:
slices (List or Tuple(List, ...)): Slice tuple or slice. slices (Union[Slice, tuple[Slice]]): Slice tuple or slice.
shape (Tuple): The shape of a sensor is an integer element tuple. shape (tuple): The shape of a sensor is an integer element tuple.
Outputs: Outputs:
(List, List, List), This is expressed as (begins, ends, strides). tuple[list], This is expressed as (begins, ends, strides).
""" """
begin = [] begin = []
end = [] end = []
strides = [] strides = []
index = 0 index = 0
slices = None slices = None
# Slice or Tuple(Slice...) # Slice or tuple(Slice...)
if isinstance(input_slices, Slice): if isinstance(input_slices, Slice):
slices = (input_slices,) slices = (input_slices,)
elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], Slice): elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], Slice):
...@@ -119,14 +110,15 @@ def slice_expand(input_slices, shape): ...@@ -119,14 +110,15 @@ def slice_expand(input_slices, shape):
index += 1 index += 1
return begin, end, strides return begin, end, strides
@constexpr @constexpr
def slice2indices(input_slices, shape): def slice2indices(input_slices, shape):
""" """
Convert slice to indices. Converts slice to indices.
Inputs: Inputs:
slices (List or Tuple(List, ...)): Slice tuple or slice. slices (Union[Slice, tuple[Slice]]): Slice tuple or slice.
shape (Tuple): The shape of a sensor is an integer element tuple. shape (tuple): The shape of a tensor is an integer element tuple.
Outputs: Outputs:
Tensor, the shape is (n, 1). Tensor, the shape is (n, 1).
...@@ -145,6 +137,7 @@ def slice2indices(input_slices, shape): ...@@ -145,6 +137,7 @@ def slice2indices(input_slices, shape):
@constexpr @constexpr
def check_indices(indices_size, index): def check_indices(indices_size, index):
"""Checks indices whether is empty."""
if indices_size < 1: if indices_size < 1:
raise ValueError("The tensor's index is unreasonable. index:{}".format(index)) raise ValueError("The tensor's index is unreasonable. index:{}".format(index))
return indices_size return indices_size
...@@ -152,6 +145,7 @@ def check_indices(indices_size, index): ...@@ -152,6 +145,7 @@ def check_indices(indices_size, index):
@constexpr @constexpr
def check_indices_value_size(indices_size, value_size): def check_indices_value_size(indices_size, value_size):
"""Checks if the sizes are already matched."""
if value_size < 1: if value_size < 1:
raise ValueError("The value assigned to tensor cannot be empty.") raise ValueError("The value assigned to tensor cannot be empty.")
if value_size > 1: if value_size > 1:
...@@ -160,3 +154,30 @@ def check_indices_value_size(indices_size, value_size): ...@@ -160,3 +154,30 @@ def check_indices_value_size(indices_size, value_size):
"The value given to tensor does not match the index size. \ "The value given to tensor does not match the index size. \
value size:{}, indics size:{}".format(value_size, indices_size)) value size:{}, indics size:{}".format(value_size, indices_size))
return value_size return value_size
@constexpr
def integer_to_indices(index, shape):
"""Converts int or tuple[int] to indices."""
size = reduce(lambda x, y: x * y, shape)
range_ = np.arange(size).reshape(shape)
value = range_[index]
value = value.reshape(-1, 1)
return Tensor(value, dtype=mstype.int32)
@constexpr
def tuple_element_is_slice(indexs):
"""Judges tuple element type."""
if not indexs:
raise ValueError("Tensor's index cannot be empty.")
if isinstance(indexs, tuple) and isinstance(indexs[0], Slice):
return True
return False
@constexpr
def tuple_element_is_int(indexs):
"""Judges tuple element type."""
if not indexs:
raise ValueError("Tensor's index cannot be empty.")
if isinstance(indexs, tuple) and isinstance(indexs[0], int):
return True
return False
...@@ -25,15 +25,14 @@ setitem = base.MultitypeFuncGraph('setitem') ...@@ -25,15 +25,14 @@ setitem = base.MultitypeFuncGraph('setitem')
@setitem.register("List", "Number", "String") @setitem.register("List", "Number", "String")
def _list_setitem_with_string(data, number_index, value): def _list_setitem_with_string(data, number_index, value):
""" """
Assign value to list. Assigns value to list.
Inputs: Inputs:
data (list): Data of type lis. data (list): Data of type lis.
number_index (Number): Index of data. number_index (Number): Index of data.
value (String): Value given.
Outputs: Outputs:
List, type is same as the element type of data. list, type is same as the element type of data.
""" """
return F.list_setitem(data, number_index, value) return F.list_setitem(data, number_index, value)
...@@ -41,7 +40,7 @@ def _list_setitem_with_string(data, number_index, value): ...@@ -41,7 +40,7 @@ def _list_setitem_with_string(data, number_index, value):
@setitem.register("List", "Number", "Number") @setitem.register("List", "Number", "Number")
def _list_setitem_with_number(data, number_index, value): def _list_setitem_with_number(data, number_index, value):
""" """
Assign value to list. Assigns value to list.
Inputs: Inputs:
data (list): Data of type lis. data (list): Data of type lis.
...@@ -49,7 +48,7 @@ def _list_setitem_with_number(data, number_index, value): ...@@ -49,7 +48,7 @@ def _list_setitem_with_number(data, number_index, value):
value (Number): Value given. value (Number): Value given.
Outputs: Outputs:
List, type is same as the element type of data. list, type is same as the element type of data.
""" """
return F.list_setitem(data, number_index, value) return F.list_setitem(data, number_index, value)
...@@ -57,7 +56,7 @@ def _list_setitem_with_number(data, number_index, value): ...@@ -57,7 +56,7 @@ def _list_setitem_with_number(data, number_index, value):
@setitem.register("List", "Number", "Tensor") @setitem.register("List", "Number", "Tensor")
def _list_setitem_with_Tensor(data, number_index, value): def _list_setitem_with_Tensor(data, number_index, value):
""" """
Assign value to list. Assigns value to list.
Inputs: Inputs:
data (list): Data of type lis. data (list): Data of type lis.
...@@ -65,7 +64,7 @@ def _list_setitem_with_Tensor(data, number_index, value): ...@@ -65,7 +64,7 @@ def _list_setitem_with_Tensor(data, number_index, value):
value (Tensor): Value given. value (Tensor): Value given.
Outputs: Outputs:
List, type is same as the element type of data. list, type is same as the element type of data.
""" """
return F.list_setitem(data, number_index, value) return F.list_setitem(data, number_index, value)
...@@ -73,15 +72,15 @@ def _list_setitem_with_Tensor(data, number_index, value): ...@@ -73,15 +72,15 @@ def _list_setitem_with_Tensor(data, number_index, value):
@setitem.register("List", "Number", "List") @setitem.register("List", "Number", "List")
def _list_setitem_with_List(data, number_index, value): def _list_setitem_with_List(data, number_index, value):
""" """
Assign value to list. Assigns value to list.
Inputs: Inputs:
data (list): Data of type lis. data (list): Data of type lis.
number_index (Number): Index of data. number_index (Number): Index of data.
value (List): Value given. value (list): Value given.
Outputs: Outputs:
List, type is same as the element type of data. list, type is same as the element type of data.
""" """
return F.list_setitem(data, number_index, value) return F.list_setitem(data, number_index, value)
...@@ -89,15 +88,15 @@ def _list_setitem_with_List(data, number_index, value): ...@@ -89,15 +88,15 @@ def _list_setitem_with_List(data, number_index, value):
@setitem.register("Dictionary", "String", "Tensor") @setitem.register("Dictionary", "String", "Tensor")
def _dict_setitem_with_tensor(data, key, value): def _dict_setitem_with_tensor(data, key, value):
""" """
Assign value to dictionary. Assigns value to dictionary.
Inputs: Inputs:
data (Dictionary): Data of type dict. data (dict): Data of type dict.
key (str): Key of the data. key (str): Key of the data.
value (Tensor): Value given. value (Tensor): Value given.
Outputs: Outputs:
Dict, type is as same as the element type of data. dict, type is as same as the element type of data.
""" """
return F.dict_setitem(data, key, value) return F.dict_setitem(data, key, value)
...@@ -105,15 +104,15 @@ def _dict_setitem_with_tensor(data, key, value): ...@@ -105,15 +104,15 @@ def _dict_setitem_with_tensor(data, key, value):
@setitem.register("Dictionary", "String", "Number") @setitem.register("Dictionary", "String", "Number")
def _dict_setitem_with_number(data, key, value): def _dict_setitem_with_number(data, key, value):
""" """
Assign value to dictionary. Assigns value to dictionary.
Inputs: Inputs:
data (Dictionary): Data of type dict. data (dict): Data of type dict.
key (str): Key of the data. key (str): Key of the data.
value (Number): Value given. value (Number): Value given.
Outputs: Outputs:
Dict, type is as same as the element type of data. dict, type is as same as the element type of data.
""" """
return F.dict_setitem(data, key, value) return F.dict_setitem(data, key, value)
...@@ -219,14 +218,14 @@ def _tensor_setitem_with_slice_v4(data, input_slice, value): ...@@ -219,14 +218,14 @@ def _tensor_setitem_with_slice_v4(data, input_slice, value):
Tensor assignment. Tensor assignment.
Note: Note:
Syntax support: A[Slice] = U Syntax support: A[tuple(Slice)] = U, and A[tuple(Number)] = U
Restraint condition: A is a Tensor Restraint condition: A is a Tensor
Slice like "1:3, ::, :4:-1" Slice like "1:3, ::, :4:-1"
U is a Tensor(size=1) or Tensor(size>1) U is a Tensor(size=1) or Tensor(size>1)
Inputs: Inputs:
data (Tensor): Assigned tensor. data (Tensor): Assigned tensor.
input_slice (Tuple(Slice)): Slice expression. input_slice (Union[tuple[Slice], tuple[Number]]): Slice expression.
value (Number): Assignment value. value (Number): Assignment value.
Outputs: Outputs:
...@@ -236,22 +235,29 @@ def _tensor_setitem_with_slice_v4(data, input_slice, value): ...@@ -236,22 +235,29 @@ def _tensor_setitem_with_slice_v4(data, input_slice, value):
def _tensor_assgin_tensor(data, input_slice, value): def _tensor_assgin_tensor(data, input_slice, value):
"""Given a tensor value assign to tensor by slice""" """Assigns a tensor value to the tensor by slice."""
# 1. condition
result = None result = None
check_result = mult_util.check_tensor_setitem_index(input_slice) check_result = mult_util.check_tensor_setitem_index(input_slice)
if check_result: if check_result:
data_shape = F.shape(data) data_shape = F.shape(data)
indices = mult_util.slice2indices(input_slice, data_shape)
is_tuple_int = mult_util.tuple_element_is_int(input_slice)
if is_tuple_int:
indices = mult_util.integer_to_indices(input_slice, data_shape)
result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value)
return result
def _tensor_indices_tensor(data, data_shape, index, indices, value):
"""Assigns a tensor value to the tensor."""
data_size = F.size(data) data_size = F.size(data)
data_dtype = F.dtype(data) data_dtype = F.dtype(data)
indices = mult_util.slice2indices(input_slice, data_shape)
indices_size = F.size(indices) indices_size = F.size(indices)
indices_size = mult_util.check_indices(indices_size, input_slice) indices_size = mult_util.check_indices(indices_size, index)
update = F.fill(data_dtype, (indices_size,), 1) update = F.fill(data_dtype, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,)) condition_1d = F.scatter_nd(indices, update, (data_size,))
condition_1d = F.cast(condition_1d, mstype.bool_) condition_1d = F.cast(condition_1d, mstype.bool_)
condition = F.reshape(condition_1d, data_shape) condition = F.reshape(condition_1d, data_shape)
# 2. u
value_fill = None value_fill = None
value_size = F.size(value) value_size = F.size(value)
...@@ -264,10 +270,7 @@ def _tensor_assgin_tensor(data, input_slice, value): ...@@ -264,10 +270,7 @@ def _tensor_assgin_tensor(data, input_slice, value):
value_fill = F.reshape(value, (indices_size,)) value_fill = F.reshape(value, (indices_size,))
value_1d = F.scatter_nd(indices, value_fill, (data_size,)) value_1d = F.scatter_nd(indices, value_fill, (data_size,))
u = F.reshape(value_1d, data_shape) u = F.reshape(value_1d, data_shape)
# A[slice]= u -> A[B]=U -> select(B, U, A) return F.select(condition, u, data)
result = F.select(condition, u, data)
return result
@setitem.register("Tensor", "Slice", "Number") @setitem.register("Tensor", "Slice", "Number")
def _tensor_setitem_with_slice_v1(data, input_slice, value): def _tensor_setitem_with_slice_v1(data, input_slice, value):
...@@ -297,14 +300,14 @@ def _tensor_setitem_with_slice_v2(data, input_slice, value): ...@@ -297,14 +300,14 @@ def _tensor_setitem_with_slice_v2(data, input_slice, value):
Tensor assignment. Tensor assignment.
Note: Note:
Syntax support: A[Slice] = u Syntax support: A[tuple(Slice)] = u, and A[tuple(Number)] = u
Restraint condition: A is a Tensor. Restraint condition: A is a Tensor.
Slice like "1:3, ::, :4:-1" Slice like "1:3, ::, :4:-1"
u is a scalar u is a scalar
Inputs: Inputs:
data (Tensor): Assigned tensor. data (Tensor): Assigned tensor.
input_slice (Tuple(Slice)): slice expression. input_slice (Union[tuple[Slice], tuple[Number]]): slice expression.
value (Number): Assignment value. value (Number): Assignment value.
Outputs: Outputs:
...@@ -314,25 +317,46 @@ def _tensor_setitem_with_slice_v2(data, input_slice, value): ...@@ -314,25 +317,46 @@ def _tensor_setitem_with_slice_v2(data, input_slice, value):
def _tensor_assgin_number(data, input_slice, value): def _tensor_assgin_number(data, input_slice, value):
"""Given a scalar assign to tensor by slice""" """Givens a scalar assign to tensor by slice"""
# 1. condition
check_result = mult_util.check_tensor_setitem_index(input_slice) check_result = mult_util.check_tensor_setitem_index(input_slice)
result = None result = None
if check_result: if check_result:
data_shape = F.shape(data) data_shape = F.shape(data)
indices = mult_util.slice2indices(input_slice, data_shape)
is_tuple_int = mult_util.tuple_element_is_int(input_slice)
if is_tuple_int:
indices = mult_util.integer_to_indices(input_slice, data_shape)
result = _tensor_indices_number(data, data_shape, input_slice, indices, value)
return result
def _tensor_indices_number(data, data_shape, index, indices, value):
"""Assigns a scalar value to the tensor."""
data_size = F.size(data) data_size = F.size(data)
data_dtype = F.dtype(data) data_dtype = F.dtype(data)
indices = mult_util.slice2indices(input_slice, data_shape)
indices_size = F.size(indices) indices_size = F.size(indices)
indices_size = mult_util.check_indices(indices_size, input_slice) indices_size = mult_util.check_indices(indices_size, index)
update = F.fill(data_dtype, (indices_size,), 1) update = F.fill(data_dtype, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,)) condition_1d = F.scatter_nd(indices, update, (data_size,))
condition_1d = F.cast(condition_1d, mstype.bool_) condition_1d = F.cast(condition_1d, mstype.bool_)
condition = F.reshape(condition_1d, data_shape) condition = F.reshape(condition_1d, data_shape)
# 2. u
value_fill = F.fill(data_dtype, (indices_size,), value) value_fill = F.fill(data_dtype, (indices_size,), value)
value_1d = F.scatter_nd(indices, value_fill, (data_size,)) value_1d = F.scatter_nd(indices, value_fill, (data_size,))
u = F.reshape(value_1d, data_shape) u = F.reshape(value_1d, data_shape)
# A[slice]= u -> A[B]=U -> select(B, U, A) return F.select(condition, u, data)
result = F.select(condition, u, data)
return result
@setitem.register("Tensor", "Number", "Number")
def _tensor_setitem_with_int_v1(data, index, value):
"""Syntax: A[1] = 3"""
data_shape = F.shape(data)
indices = mult_util.integer_to_indices(index, data_shape)
return _tensor_indices_number(data, data_shape, index, indices, value)
@setitem.register("Tensor", "Number", "Tensor")
def _tensor_setitem_with_int_v2(data, index, value):
"""Syntax: A[1] = Tensor"""
data_shape = F.shape(data)
indices = mult_util.integer_to_indices(index, data_shape)
return _tensor_indices_tensor(data, data_shape, index, indices, value)
...@@ -138,7 +138,7 @@ class TensorAssignWithSlice(Cell): ...@@ -138,7 +138,7 @@ class TensorAssignWithSlice(Cell):
z = a z = a
return z return z
def test_tensor_assign_with_slice(): def test_tensor_assign():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True) context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
net = TensorAssignWithSlice() net = TensorAssignWithSlice()
net2= TensorAssignWithSlice2() net2= TensorAssignWithSlice2()
...@@ -147,6 +147,7 @@ def test_tensor_assign_with_slice(): ...@@ -147,6 +147,7 @@ def test_tensor_assign_with_slice():
a = np.arange(60).reshape(3,4,5) a = np.arange(60).reshape(3,4,5)
b = Tensor([1]) b = Tensor([1])
Ta = Tensor(a) Ta = Tensor(a)
Ta4d = Tensor(a.reshape(1,3,4,5))
Tb= Tensor([1,3]) Tb= Tensor([1,3])
Tc= Tensor([]) Tc= Tensor([])
t = Tensor([1, 2, 3, 4, 5, 6, 7, 8]) t = Tensor([1, 2, 3, 4, 5, 6, 7, 8])
...@@ -184,6 +185,47 @@ def test_tensor_assign_with_slice(): ...@@ -184,6 +185,47 @@ def test_tensor_assign_with_slice():
with pytest.raises(ValueError): with pytest.raises(ValueError):
net_e1(Ta, 2) net_e1(Ta, 2)
net = TensorAssignWithInteger()
# Error for A[Number] = scalar/Tensor
# 1. A[Number] = U, U is a Tensor, u.size not match
with pytest.raises(ValueError):
net(Ta, Tb)
with pytest.raises(ValueError):
net(Ta, Tc)
# 2. A[Number] = U, the number index error
with pytest.raises(IndexError):
net(Ta4d, b)
# Error for A[(n,m)] = scalar/Tensor
# 1. A[(n,m)] = U, U is a tensor. u.size not match
net = TensorAssignWithTupleInteger()
with pytest.raises(ValueError):
net(Ta, Tc)
with pytest.raises(ValueError):
net(Ta, Tb)
# 2. A[(n,m)] = U, the number index error
with pytest.raises(IndexError):
net(Ta4d, b)
class TensorAssignWithInteger(Cell):
def __init__(self):
super(TensorAssignWithInteger, self).__init__()
def construct(self, a, b):
a[1] = 1
a[0] = b
return a
class TensorAssignWithTupleInteger(Cell):
def __init__(self):
super(TensorAssignWithTupleInteger, self).__init__()
def construct(self, a, b):
a[(1)] = 1
a[(1)] = b
a[(1,1)] = b
a[(1,1)] = 1
return a
class TensorAssignWithBoolTensorIndex(Cell): class TensorAssignWithBoolTensorIndex(Cell):
def __init__(self): def __init__(self):
...@@ -273,6 +315,14 @@ def test_tensor_assign_bool_index(): ...@@ -273,6 +315,14 @@ def test_tensor_assign_bool_index():
net4(Ta, u_scalar) net4(Ta, u_scalar)
test_cases = [ test_cases = [
('TensorAssignWithTupleInteger', {
'block': TensorAssignWithTupleInteger(),
'desc_inputs': [Ta, u_tensor],
}),
('TensorAssignWithInteger', {
'block': TensorAssignWithInteger(),
'desc_inputs': [Ta, u_tensor],
}),
('TensorAssignWithSlice', { ('TensorAssignWithSlice', {
'block': TensorAssignWithSlice(), 'block': TensorAssignWithSlice(),
'desc_inputs': [Ta, u_tensor], 'desc_inputs': [Ta, u_tensor],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册