From 850171a34bc53d469351f787db57af567b989855 Mon Sep 17 00:00:00 2001 From: buxue Date: Fri, 22 May 2020 15:55:38 +0800 Subject: [PATCH] Restrict tensor getitem or setitem not support mixed tensor. --- .../ops/composite/multitype_ops/_utils.py | 4 +- .../composite/multitype_ops/getitem_impl.py | 2 +- .../composite/multitype_ops/setitem_impl.py | 6 +- tests/ut/python/ops/test_tensor_slice.py | 77 +++++++++++++------ 4 files changed, 61 insertions(+), 28 deletions(-) diff --git a/mindspore/ops/composite/multitype_ops/_utils.py b/mindspore/ops/composite/multitype_ops/_utils.py index cff88dfdb..9717ddc15 100644 --- a/mindspore/ops/composite/multitype_ops/_utils.py +++ b/mindspore/ops/composite/multitype_ops/_utils.py @@ -254,7 +254,7 @@ def tuple_element_is_int(indexs): @constexpr -def tuple_elements_type(types): +def tuple_index_elements_type(types, op_name): """Judges the type of all elements of the tuple.""" tensors_number = 0 for ele in types: @@ -264,7 +264,7 @@ def tuple_elements_type(types): return ALL_TENSOR if tensors_number == 0: return NO_TENSOR - return CONTAIN_TENSOR + raise IndexError(f"For '{op_name}', the index does not support mixed tensor.") @constexpr diff --git a/mindspore/ops/composite/multitype_ops/getitem_impl.py b/mindspore/ops/composite/multitype_ops/getitem_impl.py index 5e217ba1b..affbb1929 100644 --- a/mindspore/ops/composite/multitype_ops/getitem_impl.py +++ b/mindspore/ops/composite/multitype_ops/getitem_impl.py @@ -247,7 +247,7 @@ def _tensor_getitem_by_tuple(data, tuple_index): Tensor, element type is same as the element type of data. """ index_types = multi_utils.hyper_map(F.typeof, tuple_index) - index_elements_type = multi_utils.tuple_elements_type(index_types) + index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_GETITEM) result = None if index_elements_type == multi_utils.NO_TENSOR: result = _tensor_slice(data, tuple_index) diff --git a/mindspore/ops/composite/multitype_ops/setitem_impl.py b/mindspore/ops/composite/multitype_ops/setitem_impl.py index 49a1666de..f51dc12c2 100644 --- a/mindspore/ops/composite/multitype_ops/setitem_impl.py +++ b/mindspore/ops/composite/multitype_ops/setitem_impl.py @@ -191,7 +191,7 @@ def _tensor_setitem_by_tuple_with_number(data, tuple_index, value): Tensor, element type and shape is same as data. """ index_types = multi_utils.hyper_map(F.typeof, tuple_index) - index_elements_type = multi_utils.tuple_elements_type(index_types) + index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_SETITEM) result = None if index_elements_type == multi_utils.NO_TENSOR: result = _tensor_assgin_number(data, tuple_index, value) @@ -222,7 +222,7 @@ def _tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): Tensor, element type and shape is same as data. """ index_types = multi_utils.hyper_map(F.typeof, tuple_index) - index_elements_type = multi_utils.tuple_elements_type(index_types) + index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_SETITEM) result = None if index_elements_type == multi_utils.NO_TENSOR: result = _tensor_assgin_tensor(data, tuple_index, value) @@ -254,7 +254,7 @@ def _tensor_setitem_by_tuple_with_tuple(data, tuple_index, value): Tensor, element type and shape is same as data. """ index_types = multi_utils.hyper_map(F.typeof, tuple_index) - index_elements_type = multi_utils.tuple_elements_type(index_types) + index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_SETITEM) result = None if index_elements_type == multi_utils.ALL_TENSOR: indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM) diff --git a/tests/ut/python/ops/test_tensor_slice.py b/tests/ut/python/ops/test_tensor_slice.py index 6b4c84f14..776c43b78 100644 --- a/tests/ut/python/ops/test_tensor_slice.py +++ b/tests/ut/python/ops/test_tensor_slice.py @@ -146,9 +146,9 @@ class TensorAssignWithSlice(Cell): return z -class TensorIndexByOneTensor(Cell): +class TensorGetItemByOneTensor(Cell): def __init__(self): - super(TensorIndexByOneTensor, self).__init__() + super(TensorGetItemByOneTensor, self).__init__() self.const = Tensor(np.ones((5, 4, 7, 8)), mstype.int32) def construct(self, x, index): @@ -156,9 +156,9 @@ class TensorIndexByOneTensor(Cell): return ret -class TensorIndexByTwoTensors(Cell): +class TensorGetItemByTwoTensors(Cell): def __init__(self): - super(TensorIndexByTwoTensors, self).__init__() + super(TensorGetItemByTwoTensors, self).__init__() self.const = Tensor(np.ones((3, 4, 5, 8)), mstype.int32) def construct(self, x, index_0, index_1): @@ -166,9 +166,9 @@ class TensorIndexByTwoTensors(Cell): return ret -class TensorIndexByThreeTensors(Cell): +class TensorGetItemByThreeTensors(Cell): def __init__(self): - super(TensorIndexByThreeTensors, self).__init__() + super(TensorGetItemByThreeTensors, self).__init__() self.const = Tensor(np.ones((5, 3, 4, 5)), mstype.int32) def construct(self, x, index_0, index_1, index_2): @@ -176,6 +176,15 @@ class TensorIndexByThreeTensors(Cell): return ret +class TensorGetItemByMixedTensors(Cell): + def __init__(self): + super(TensorGetItemByMixedTensors, self).__init__() + + def construct(self, x, index_0, index_1): + ret = x[index_0, index_1, 0:6] + return ret + + class TensorSetItemByOneTensorWithNumber(Cell): def __init__(self, value): super(TensorSetItemByOneTensorWithNumber, self).__init__() @@ -300,6 +309,19 @@ class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell): return ret +class TensorSetItemByMixedTensors(Cell): + def __init__(self): + super(TensorSetItemByMixedTensors, self).__init__() + self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) + self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") + self.value = 99.0 + + def construct(self, index_0, index_1): + self.param[index_0, index_1, 0:6] = self.value + ret = self.param + self.const + return ret + + def test_tensor_assign(): context.set_context(mode=context.GRAPH_MODE, save_graphs=True) net = TensorAssignWithSlice() @@ -596,19 +618,19 @@ test_cases = [ 'block': NetWorkSliceEllipsis(), 'desc_inputs': [Tensor(np.ones([6, 7, 8, 9], np.int32))], }), - ('TensorIndexByOneTensor', { - 'block': TensorIndexByOneTensor(), + ('TensorGetItemByOneTensor', { + 'block': TensorGetItemByOneTensor(), 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), Tensor(np.random.randint(6, size=(5, 4)), mstype.int32)], }), - ('TensorIndexByTwoTensors', { - 'block': TensorIndexByTwoTensors(), + ('TensorGetItemByTwoTensors', { + 'block': TensorGetItemByTwoTensors(), 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)], }), - ('TensorIndexByThreeTensors', { - 'block': TensorIndexByThreeTensors(), + ('TensorGetItemByThreeTensors', { + 'block': TensorGetItemByThreeTensors(), 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), @@ -665,37 +687,43 @@ test_cases = [ ] raise_error_set = [ - ('TensorIndexByOneTensorDtypeError', { - 'block': (TensorIndexByOneTensor(), {'exception': TypeError}), + ('TensorGetItemByOneTensorDtypeError', { + 'block': (TensorGetItemByOneTensor(), {'exception': TypeError}), 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), Tensor(np.random.randint(6, size=(5, 4)), mstype.int8)], }), - ('TensorIndexByTwoTensorsShapeError', { - 'block': (TensorIndexByTwoTensors(), {'exception': ValueError}), + ('TensorGetItemByTwoTensorsShapeError', { + 'block': (TensorGetItemByTwoTensors(), {'exception': ValueError}), 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), Tensor(np.random.randint(7, size=(2, 3, 5)), mstype.int32)], }), - ('TensorIndexByTwoTensorsDtypeError', { - 'block': (TensorIndexByTwoTensors(), {'exception': TypeError}), + ('TensorGetItemByTwoTensorsDtypeError', { + 'block': (TensorGetItemByTwoTensors(), {'exception': TypeError}), 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), Tensor(np.random.randint(7, size=(4, 5)), mstype.float32)], }), - ('TensorIndexByThreeTensorsShapeError', { - 'block': (TensorIndexByThreeTensors(), {'exception': ValueError}), + ('TensorGetItemByThreeTensorsShapeError', { + 'block': (TensorGetItemByThreeTensors(), {'exception': ValueError}), 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32), Tensor(np.random.randint(8, size=(5, 2, 4, 5)), mstype.int32)], }), - ('TensorIndexByThreeTensorsDtypeError', { - 'block': (TensorIndexByThreeTensors(), {'exception': TypeError}), + ('TensorGetItemByThreeTensorsDtypeError', { + 'block': (TensorGetItemByThreeTensors(), {'exception': TypeError}), 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int64), Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], }), + ('TensorGetItemByMixedTensors', { + 'block': (TensorGetItemByMixedTensors(), {'exception': IndexError}), + 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), + Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int64)], + }), ('TensorSetItemByOneTensorWithNumberTypeError', { 'block': (TensorSetItemByOneTensorWithNumber(value=0), {'exception': TypeError}), 'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)], @@ -781,6 +809,11 @@ raise_error_set = [ Tensor(np.zeros((4, 5)), mstype.float32), Tensor(np.ones((4, 5)), mstype.int32), Tensor(np.ones((4, 5)) * 2, mstype.int32)], + }), + ('TensorSetItemByMixedTensors', { + 'block': (TensorSetItemByMixedTensors(), {'exception': IndexError}), + 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)], }) ] -- GitLab