提交 a02eb240 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!595 Tensor assign with int or tuple(int) index

Merge pull request !595 from candanzg/tensor_assign_with_integer
......@@ -15,6 +15,7 @@
"""constexpr util"""
from functools import reduce
import numpy as np
from ...primitive import constexpr
from ....common.tensor import Tensor
......@@ -23,26 +24,27 @@ from ...._extends.utils import Slice
@constexpr
def check_equal(param1, param2, msg="{},{}"):
"""Checks whether the two parameters are equal or not."""
if param1 != param2:
raise ValueError(msg.format(param1, param2))
return param1
@constexpr
def check_tensor_setitem_index(index, element_type=None):
"""Check tuple index type of tensor assignment."""
"""Checks tuple index type of tensor assignment."""
if index is None:
raise ValueError("Tensor's index cannot be None.")
# eg. Tensor[Slice] = u
if isinstance(index, Slice):
return True
# eg. Tensor[Tuple] = u
# eg. Tensor[tuple] = u
if isinstance(index, tuple):
if not index:
raise ValueError("Tensor's index cannot be empty.")
# eg. Tensor[Tuple(Slice...)] = u
if not isinstance(index[0], Slice):
raise ValueError("Index of type '{}' is not supported yet.".format(type(index[0])))
return True
# eg. Tensor[tuple(Slice...)] = u
if isinstance(index[0], (Slice, int)):
return True
raise ValueError("Index of type '{}' is not supported yet.".format(type(index[0])))
# eg. Tensor[Tensor[dtype=bool]] = u
if index == mstype.tensor:
if element_type is None or element_type != mstype.bool_:
......@@ -57,7 +59,7 @@ def check_tensor_setitem_index(index, element_type=None):
@constexpr
def is_same_type(inst, type_):
"""
Check whether an object is an instance of a target type.
Checks whether an object is an instance of a target type.
Inputs:
inst (mindspore.dtype): Inspected type.
......@@ -69,34 +71,23 @@ def is_same_type(inst, type_):
return inst == type_
@constexpr
def error_msg(msg="", format_values=""):
"""
Used to throw exception information.
Inputs:
msg (str): information content.
"""
raise ValueError(msg.format(*format_values))
def slice_expand(input_slices, shape):
"""
Convert slice to indices.
Converts slice to indices.
Inputs:
slices (List or Tuple(List, ...)): Slice tuple or slice.
shape (Tuple): The shape of a sensor is an integer element tuple.
slices (Union[Slice, tuple[Slice]]): Slice tuple or slice.
shape (tuple): The shape of a sensor is an integer element tuple.
Outputs:
(List, List, List), This is expressed as (begins, ends, strides).
tuple[list], This is expressed as (begins, ends, strides).
"""
begin = []
end = []
strides = []
index = 0
slices = None
# Slice or Tuple(Slice...)
# Slice or tuple(Slice...)
if isinstance(input_slices, Slice):
slices = (input_slices,)
elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], Slice):
......@@ -119,14 +110,15 @@ def slice_expand(input_slices, shape):
index += 1
return begin, end, strides
@constexpr
def slice2indices(input_slices, shape):
"""
Convert slice to indices.
Converts slice to indices.
Inputs:
slices (List or Tuple(List, ...)): Slice tuple or slice.
shape (Tuple): The shape of a sensor is an integer element tuple.
slices (Union[Slice, tuple[Slice]]): Slice tuple or slice.
shape (tuple): The shape of a tensor is an integer element tuple.
Outputs:
Tensor, the shape is (n, 1).
......@@ -145,6 +137,7 @@ def slice2indices(input_slices, shape):
@constexpr
def check_indices(indices_size, index):
"""Checks indices whether is empty."""
if indices_size < 1:
raise ValueError("The tensor's index is unreasonable. index:{}".format(index))
return indices_size
......@@ -152,6 +145,7 @@ def check_indices(indices_size, index):
@constexpr
def check_indices_value_size(indices_size, value_size):
"""Checks if the sizes are already matched."""
if value_size < 1:
raise ValueError("The value assigned to tensor cannot be empty.")
if value_size > 1:
......@@ -160,3 +154,30 @@ def check_indices_value_size(indices_size, value_size):
"The value given to tensor does not match the index size. \
value size:{}, indics size:{}".format(value_size, indices_size))
return value_size
@constexpr
def integer_to_indices(index, shape):
"""Converts int or tuple[int] to indices."""
size = reduce(lambda x, y: x * y, shape)
range_ = np.arange(size).reshape(shape)
value = range_[index]
value = value.reshape(-1, 1)
return Tensor(value, dtype=mstype.int32)
@constexpr
def tuple_element_is_slice(indexs):
"""Judges tuple element type."""
if not indexs:
raise ValueError("Tensor's index cannot be empty.")
if isinstance(indexs, tuple) and isinstance(indexs[0], Slice):
return True
return False
@constexpr
def tuple_element_is_int(indexs):
"""Judges tuple element type."""
if not indexs:
raise ValueError("Tensor's index cannot be empty.")
if isinstance(indexs, tuple) and isinstance(indexs[0], int):
return True
return False
......@@ -25,15 +25,14 @@ setitem = base.MultitypeFuncGraph('setitem')
@setitem.register("List", "Number", "String")
def _list_setitem_with_string(data, number_index, value):
"""
Assign value to list.
Assigns value to list.
Inputs:
data (list): Data of type lis.
number_index (Number): Index of data.
value (String): Value given.
Outputs:
List, type is same as the element type of data.
list, type is same as the element type of data.
"""
return F.list_setitem(data, number_index, value)
......@@ -41,7 +40,7 @@ def _list_setitem_with_string(data, number_index, value):
@setitem.register("List", "Number", "Number")
def _list_setitem_with_number(data, number_index, value):
"""
Assign value to list.
Assigns value to list.
Inputs:
data (list): Data of type lis.
......@@ -49,7 +48,7 @@ def _list_setitem_with_number(data, number_index, value):
value (Number): Value given.
Outputs:
List, type is same as the element type of data.
list, type is same as the element type of data.
"""
return F.list_setitem(data, number_index, value)
......@@ -57,7 +56,7 @@ def _list_setitem_with_number(data, number_index, value):
@setitem.register("List", "Number", "Tensor")
def _list_setitem_with_Tensor(data, number_index, value):
"""
Assign value to list.
Assigns value to list.
Inputs:
data (list): Data of type lis.
......@@ -65,7 +64,7 @@ def _list_setitem_with_Tensor(data, number_index, value):
value (Tensor): Value given.
Outputs:
List, type is same as the element type of data.
list, type is same as the element type of data.
"""
return F.list_setitem(data, number_index, value)
......@@ -73,15 +72,15 @@ def _list_setitem_with_Tensor(data, number_index, value):
@setitem.register("List", "Number", "List")
def _list_setitem_with_List(data, number_index, value):
"""
Assign value to list.
Assigns value to list.
Inputs:
data (list): Data of type lis.
number_index (Number): Index of data.
value (List): Value given.
value (list): Value given.
Outputs:
List, type is same as the element type of data.
list, type is same as the element type of data.
"""
return F.list_setitem(data, number_index, value)
......@@ -89,15 +88,15 @@ def _list_setitem_with_List(data, number_index, value):
@setitem.register("Dictionary", "String", "Tensor")
def _dict_setitem_with_tensor(data, key, value):
"""
Assign value to dictionary.
Assigns value to dictionary.
Inputs:
data (Dictionary): Data of type dict.
data (dict): Data of type dict.
key (str): Key of the data.
value (Tensor): Value given.
Outputs:
Dict, type is as same as the element type of data.
dict, type is as same as the element type of data.
"""
return F.dict_setitem(data, key, value)
......@@ -105,15 +104,15 @@ def _dict_setitem_with_tensor(data, key, value):
@setitem.register("Dictionary", "String", "Number")
def _dict_setitem_with_number(data, key, value):
"""
Assign value to dictionary.
Assigns value to dictionary.
Inputs:
data (Dictionary): Data of type dict.
data (dict): Data of type dict.
key (str): Key of the data.
value (Number): Value given.
Outputs:
Dict, type is as same as the element type of data.
dict, type is as same as the element type of data.
"""
return F.dict_setitem(data, key, value)
......@@ -219,14 +218,14 @@ def _tensor_setitem_with_slice_v4(data, input_slice, value):
Tensor assignment.
Note:
Syntax support: A[Slice] = U
Syntax support: A[tuple(Slice)] = U, and A[tuple(Number)] = U
Restraint condition: A is a Tensor
Slice like "1:3, ::, :4:-1"
U is a Tensor(size=1) or Tensor(size>1)
Inputs:
data (Tensor): Assigned tensor.
input_slice (Tuple(Slice)): Slice expression.
input_slice (Union[tuple[Slice], tuple[Number]]): Slice expression.
value (Number): Assignment value.
Outputs:
......@@ -236,39 +235,43 @@ def _tensor_setitem_with_slice_v4(data, input_slice, value):
def _tensor_assgin_tensor(data, input_slice, value):
"""Given a tensor value assign to tensor by slice"""
# 1. condition
"""Assigns a tensor value to the tensor by slice."""
result = None
check_result = mult_util.check_tensor_setitem_index(input_slice)
if check_result:
data_shape = F.shape(data)
data_size = F.size(data)
data_dtype = F.dtype(data)
indices = mult_util.slice2indices(input_slice, data_shape)
indices_size = F.size(indices)
indices_size = mult_util.check_indices(indices_size, input_slice)
update = F.fill(data_dtype, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition_1d = F.cast(condition_1d, mstype.bool_)
condition = F.reshape(condition_1d, data_shape)
# 2. u
value_fill = None
value_size = F.size(value)
value_size = mult_util.check_indices_value_size(indices_size, value_size)
if value_size == 1:
value_fill = F.fill(data_dtype, (indices_size,), 1)
value = F.cast(value, data_dtype)
value_fill = F.tensor_mul(value_fill, value)
elif value_size > 1:
value_fill = F.reshape(value, (indices_size,))
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
u = F.reshape(value_1d, data_shape)
# A[slice]= u -> A[B]=U -> select(B, U, A)
result = F.select(condition, u, data)
is_tuple_int = mult_util.tuple_element_is_int(input_slice)
if is_tuple_int:
indices = mult_util.integer_to_indices(input_slice, data_shape)
result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value)
return result
def _tensor_indices_tensor(data, data_shape, index, indices, value):
"""Assigns a tensor value to the tensor."""
data_size = F.size(data)
data_dtype = F.dtype(data)
indices_size = F.size(indices)
indices_size = mult_util.check_indices(indices_size, index)
update = F.fill(data_dtype, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition_1d = F.cast(condition_1d, mstype.bool_)
condition = F.reshape(condition_1d, data_shape)
value_fill = None
value_size = F.size(value)
value_size = mult_util.check_indices_value_size(indices_size, value_size)
if value_size == 1:
value_fill = F.fill(data_dtype, (indices_size,), 1)
value = F.cast(value, data_dtype)
value_fill = F.tensor_mul(value_fill, value)
elif value_size > 1:
value_fill = F.reshape(value, (indices_size,))
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
u = F.reshape(value_1d, data_shape)
return F.select(condition, u, data)
@setitem.register("Tensor", "Slice", "Number")
def _tensor_setitem_with_slice_v1(data, input_slice, value):
"""
......@@ -297,14 +300,14 @@ def _tensor_setitem_with_slice_v2(data, input_slice, value):
Tensor assignment.
Note:
Syntax support: A[Slice] = u
Syntax support: A[tuple(Slice)] = u, and A[tuple(Number)] = u
Restraint condition: A is a Tensor.
Slice like "1:3, ::, :4:-1"
u is a scalar
Inputs:
data (Tensor): Assigned tensor.
input_slice (Tuple(Slice)): slice expression.
input_slice (Union[tuple[Slice], tuple[Number]]): slice expression.
value (Number): Assignment value.
Outputs:
......@@ -314,25 +317,46 @@ def _tensor_setitem_with_slice_v2(data, input_slice, value):
def _tensor_assgin_number(data, input_slice, value):
"""Given a scalar assign to tensor by slice"""
# 1. condition
"""Givens a scalar assign to tensor by slice"""
check_result = mult_util.check_tensor_setitem_index(input_slice)
result = None
if check_result:
data_shape = F.shape(data)
data_size = F.size(data)
data_dtype = F.dtype(data)
indices = mult_util.slice2indices(input_slice, data_shape)
indices_size = F.size(indices)
indices_size = mult_util.check_indices(indices_size, input_slice)
update = F.fill(data_dtype, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition_1d = F.cast(condition_1d, mstype.bool_)
condition = F.reshape(condition_1d, data_shape)
# 2. u
value_fill = F.fill(data_dtype, (indices_size,), value)
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
u = F.reshape(value_1d, data_shape)
# A[slice]= u -> A[B]=U -> select(B, U, A)
result = F.select(condition, u, data)
is_tuple_int = mult_util.tuple_element_is_int(input_slice)
if is_tuple_int:
indices = mult_util.integer_to_indices(input_slice, data_shape)
result = _tensor_indices_number(data, data_shape, input_slice, indices, value)
return result
def _tensor_indices_number(data, data_shape, index, indices, value):
"""Assigns a scalar value to the tensor."""
data_size = F.size(data)
data_dtype = F.dtype(data)
indices_size = F.size(indices)
indices_size = mult_util.check_indices(indices_size, index)
update = F.fill(data_dtype, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition_1d = F.cast(condition_1d, mstype.bool_)
condition = F.reshape(condition_1d, data_shape)
value_fill = F.fill(data_dtype, (indices_size,), value)
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
u = F.reshape(value_1d, data_shape)
return F.select(condition, u, data)
@setitem.register("Tensor", "Number", "Number")
def _tensor_setitem_with_int_v1(data, index, value):
"""Syntax: A[1] = 3"""
data_shape = F.shape(data)
indices = mult_util.integer_to_indices(index, data_shape)
return _tensor_indices_number(data, data_shape, index, indices, value)
@setitem.register("Tensor", "Number", "Tensor")
def _tensor_setitem_with_int_v2(data, index, value):
"""Syntax: A[1] = Tensor"""
data_shape = F.shape(data)
indices = mult_util.integer_to_indices(index, data_shape)
return _tensor_indices_tensor(data, data_shape, index, indices, value)
......@@ -139,7 +139,7 @@ class TensorAssignWithSlice(Cell):
z = a
return z
def test_tensor_assign_with_slice():
def test_tensor_assign():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
net = TensorAssignWithSlice()
net2= TensorAssignWithSlice2()
......@@ -148,6 +148,7 @@ def test_tensor_assign_with_slice():
a = np.arange(60).reshape(3,4,5)
b = Tensor([1])
Ta = Tensor(a)
Ta4d = Tensor(a.reshape(1,3,4,5))
Tb= Tensor([1,3])
Tc= Tensor([])
t = Tensor([1, 2, 3, 4, 5, 6, 7, 8])
......@@ -185,6 +186,47 @@ def test_tensor_assign_with_slice():
with pytest.raises(ValueError):
net_e1(Ta, 2)
net = TensorAssignWithInteger()
# Error for A[Number] = scalar/Tensor
# 1. A[Number] = U, U is a Tensor, u.size not match
with pytest.raises(ValueError):
net(Ta, Tb)
with pytest.raises(ValueError):
net(Ta, Tc)
# 2. A[Number] = U, the number index error
with pytest.raises(IndexError):
net(Ta4d, b)
# Error for A[(n,m)] = scalar/Tensor
# 1. A[(n,m)] = U, U is a tensor. u.size not match
net = TensorAssignWithTupleInteger()
with pytest.raises(ValueError):
net(Ta, Tc)
with pytest.raises(ValueError):
net(Ta, Tb)
# 2. A[(n,m)] = U, the number index error
with pytest.raises(IndexError):
net(Ta4d, b)
class TensorAssignWithInteger(Cell):
def __init__(self):
super(TensorAssignWithInteger, self).__init__()
def construct(self, a, b):
a[1] = 1
a[0] = b
return a
class TensorAssignWithTupleInteger(Cell):
def __init__(self):
super(TensorAssignWithTupleInteger, self).__init__()
def construct(self, a, b):
a[(1)] = 1
a[(1)] = b
a[(1,1)] = b
a[(1,1)] = 1
return a
class TensorAssignWithBoolTensorIndex(Cell):
def __init__(self):
......@@ -274,6 +316,14 @@ def test_tensor_assign_bool_index():
net4(Ta, u_scalar)
test_cases = [
('TensorAssignWithTupleInteger', {
'block': TensorAssignWithTupleInteger(),
'desc_inputs': [Ta, u_tensor],
}),
('TensorAssignWithInteger', {
'block': TensorAssignWithInteger(),
'desc_inputs': [Ta, u_tensor],
}),
('TensorAssignWithSlice', {
'block': TensorAssignWithSlice(),
'desc_inputs': [Ta, u_tensor],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册