提交 bc0455ea 编写于 作者: H huangdongrun

* fix bool index

* change slice setitem to mixed procedure
* add testcase for slice assignment
上级 0478b7d1
......@@ -167,12 +167,13 @@ def _tensor_getitem(self, index):
return tensor_index_by_tensor(self, index)
if isinstance(index, tuple):
return tensor_index_by_tuple(self, index)
# bool type should be judged before int
if isinstance(index, bool):
return _tensor_index_by_bool(self, index)
if isinstance(index, int):
return _tensor_index_by_integer(self, index)
if isinstance(index, slice):
return tensor_index_by_slice(self, index)
if isinstance(index, bool):
return _tensor_index_by_bool(self, index)
if index is None:
return F.expand_dims(self, 0)
if index is ...:
......@@ -206,7 +207,8 @@ def tensor_index_by_slice(data, slice_index):
"""Tensor getitem by a single slice"""
shape = F.shape(data)
if not shape:
const_utils.raise_index_error("When tensor is indexed by a slice, the dimension of the tensor cannot be 0.")
const_utils.raise_index_error("When tensor is indexed by a slice, the dimension of the tensor\
cannot be 0.")
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_slice(shape, slice_index)
return F.strided_slice(data, begin_strides, end_strides, step_strides)
......@@ -215,7 +217,11 @@ def _tensor_index_by_integer(data, number):
"""Tensor getitem by a single integer number"""
shape = F.shape(data)
if not shape:
const_utils.raise_index_error("When tensor is indexed by an integer, the dimension of the tensor cannot be 0.")
return const_utils.raise_index_error("When tensor is indexed by an integer,\
the dimension of the tensor cannot be 0.")
if number >= shape[0]:
return const_utils.raise_index_error("index {} is out of bounds for axis 0 with size {}".format(
number, shape[0]))
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_integer(shape, number)
shrink_axis_mask = 1
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)
......@@ -427,8 +433,6 @@ def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
indexes_types = hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
if index_elements_type == const_utils.NO_TENSOR:
return tensor_setitem_by_slice_with_number(data, tuple_index, value)
if index_elements_type == const_utils.ALL_TENSOR:
indices = _generate_indices_from_tuple_of_tensor(data,
tuple_index,
......@@ -488,8 +492,6 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
indexes_types = hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
if index_elements_type == const_utils.NO_TENSOR:
return tensor_setitem_by_slice_with_tensor(data, tuple_index, value)
if index_elements_type == const_utils.ALL_TENSOR:
indices = _generate_indices_from_tuple_of_tensor(data,
tuple_index,
......
......@@ -339,6 +339,8 @@ def check_tensors_dtype_same(data_dtype, value_dtype, op_name):
@constexpr
def generate_broadcast_shape(shapes, op_name):
"""Generate broadcast shape for a tuple of shape."""
if not shapes:
return ()
broadcast_shape = shapes[0]
for i, shape in enumerate(shapes):
logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.")
......@@ -541,6 +543,11 @@ def generate_index_info_from_tuple_of_mixed_tensors(data_shape,
slice_indexes[slice_count].step)
# Use list to represent slicing result.
indexes_info[pos] = list(range(data_shape[pos]))[slice_obj]
if not indexes_info[pos]:
raise IndexError("An empty slice is not supported, got {}:{}:{}".format(
slice_indexes[slice_count].start,
slice_indexes[slice_count].stop,
slice_indexes[slice_count].step))
slice_count += 1
elif isinstance(ele_type, mstype.ellipsis_type):
if ellipsis_num != 0:
......
......@@ -646,7 +646,7 @@ class TensorAssignWithSlice2(Cell):
class TensorAssignWithSlice(Cell):
def __init__(self):
super(TensorAssignWithSlice, self).__init__()
self.c = 2
self.c = 2.0
def construct(self, a, b, ck):
a[1:3, ::] = b
......@@ -661,7 +661,47 @@ class TensorAssignWithSlice(Cell):
return z
def test_tensor_assign():
def test_tensor_assign_slice_value_1():
net = TensorAssignWithSlice()
a = np.arange(60).reshape(3, 4, 5)
ck = np.arange(60).reshape(3, 4, 5)
b = np.array([1]).astype(np.float32) # Tensor([1], dtype=mstype.float32)
tb = Tensor(b, dtype=mstype.float32)
ta = Tensor(a, dtype=mstype.float32)
tck = Tensor(ck, dtype=mstype.float32)
out = net(ta, tb, tck)
a[1:3, ::] = b
a[2:3:, 3:] = b
a[::] = b
a[::] = 2.0
a[::, ::] = b
a[::, ::] = 2.0
a[2:3:, 0:, 4:1:-1] = b
a[2:3:, 0:, 4:1:-1] = 2.0
z = a + ck
assert np.all(z == out.asnumpy())
def test_tensor_assign_slice_value_2():
net2 = TensorAssignWithSlice2()
a = np.array([1, 2, 3, 4, 5, 6, 7, 8])
ck = np.array([1, 2, 3, 4, 5, 6, 7, 8])
b = np.array([1]).astype(np.float32) # Tensor([1], dtype=mstype.float32)
tb = Tensor(b, dtype=mstype.float32)
ta = Tensor(a, dtype=mstype.float32)
tck = Tensor(ck, dtype=mstype.float32)
out = net2(ta, tb, tck)
a[1:5] = b
a[3:4] = 5
a[-1:1:-1] = b
a[-1:3:-1] = 5
a[::] = b
a[::] = 9
z = a + ck
assert np.all(z == out.asnumpy())
def test_tensor_assign_exception():
net = TensorAssignWithSlice()
net2 = TensorAssignWithSlice2()
net_e1 = TensorAssignWithSliceError1()
......@@ -677,8 +717,6 @@ def test_tensor_assign():
Tc = Tensor([], dtype=mstype.float32)
t = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
tck = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
net(Ta, b, Tck)
net2(t, b, tck)
# Error for A[Slice] = Number
# 1. A[Slice] = Number, Slice error
with pytest.raises(IndexError):
......@@ -744,9 +782,6 @@ def test_tensor_assign():
# 2. A[::, 1:, ...] = scalar/tensor
net = TensorAssignWithTupleEllipsis()
net(Ta, b)
Tc = Tensor(1, mstype.float32)
with pytest.raises(ValueError):
net(Ta, Tc)
with pytest.raises(ValueError):
net(Ta, Tb)
......@@ -765,7 +800,7 @@ class TensorAssignWithTupleEllipsis(Cell):
super(TensorAssignWithTupleEllipsis, self).__init__()
def construct(self, a, b):
a[:2, ...] = 1
a[:2, ...] = 1.0
a[1:, ...] = b
return a
......@@ -955,3 +990,16 @@ def Xtest_tensor_slice_reduce_out_of_bounds_positive():
with pytest.raises(ValueError) as ex:
net(input_tensor)
assert "For 'StridedSlice' the `begin[0]` should be an int and must less than 6, but got `6`" in str(ex.value)
def test_tensor_range():
a = np.arange(4*5*6).reshape(4, 5, 6).astype(np.float32)
ta = Tensor(a, mstype.float32)
ms_out = []
for item in ta:
ms_out.append(item)
np_out = []
for item in a:
np_out.append(item)
for i, elem in enumerate(ms_out):
assert np.all(elem.asnumpy() == np_out[i])
......@@ -130,7 +130,7 @@ class TensorAssignWithSlice2(Cell):
class TensorAssignWithSlice(Cell):
def __init__(self):
super(TensorAssignWithSlice, self).__init__()
self.c = 2
self.c = 2.0
def construct(self, a, b, ck):
a[1:3, ::] = b
......@@ -528,8 +528,7 @@ def test_tensor_assign():
net = TensorAssignWithTupleEllipsis()
net(Ta, b)
Tc = Tensor(1, mstype.float32)
with pytest.raises(ValueError):
net(Ta, Tc)
net(Ta, Tc)
with pytest.raises(ValueError):
net(Ta, Tb)
......@@ -548,7 +547,7 @@ class TensorAssignWithTupleEllipsis(Cell):
super(TensorAssignWithTupleEllipsis, self).__init__()
def construct(self, a, b):
a[:2, ...] = 1
a[:2, ...] = 1.0
a[1:, ...] = b
return a
......@@ -579,10 +578,10 @@ class TensorAssignWithTupleInteger(Cell):
super(TensorAssignWithTupleInteger, self).__init__()
def construct(self, a, b, ck):
a[(1)] = 1
a[(1)] = 1.0
a[(1)] = b
a[(1, 1)] = b
a[(1, 1)] = 1
a[(1, 1)] = 1.0
z = a + ck
return z
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册