Commit b9ad407d authored by mindspore-ci-bot, committed by Gitee

!2204 add check for tensor slice before convert to ops

Merge pull request !2204 from zhangbuxue/add_check_for_tensor_slice_before_convert_to_ops
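In short, this change adds dimension checks before tensor slice and integer indexing is lowered to strided-slice ops, so that indexing a 0-dimensional tensor fails with a clear IndexError instead of a low-level operator error, and it tightens several of the related error messages. A minimal sketch of the intended behavior, assuming graph mode on a build that contains this change; the exact exception type, wording, and how it surfaces may vary by MindSpore version:

```python
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, context

context.set_context(mode=context.GRAPH_MODE)

class SliceScalar(nn.Cell):
    def construct(self, x):
        return x[0:1]                       # slice index on the input tensor

net = SliceScalar()
scalar = Tensor(np.array(5.0), ms.float32)  # 0-dimensional tensor
try:
    net(scalar)                             # the new check rejects slicing a 0-dim tensor
except (IndexError, RuntimeError) as e:     # surfaced exception type may vary by version
    print(e)
```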
......@@ -156,7 +156,6 @@ def generate_updates_from_tensor(data, index, value, op_type):
return value
def tensor_getitem(self, index):
"""Handle tensor getitem"""
if isinstance(index, Tensor):
......@@ -164,16 +163,15 @@ def tensor_getitem(self, index):
if isinstance(index, tuple):
return tensor_index_by_tuple(self, index)
if isinstance(index, int):
return tensor_index_by_number(self, index)
return tensor_index_by_integer(self, index)
if isinstance(index, slice):
return tensor_index_by_slice(self, index)
if isinstance(index, bool):
return tensor_index_by_bool(self, index)
if index is ...:
return self
raise IndexError("Only support integers, slices(`:`), ellipsis(`...`), None, bool and tensor with int32,\
got {} with type{}".format(index, type(index)))
raise IndexError(f"Only support integers, slices(`:`), ellipsis(`...`), None, bool and tensor with int32, "
f"got {index} with type {type(index)}.")
tensor_operator_registry.register("__getitem__", tensor_getitem)
......@@ -199,13 +197,19 @@ def tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index):
def tensor_index_by_slice(data, slice_index):
"""Tensor getitem by a single slice"""
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_slice(F.shape(data), slice_index)
shape = F.shape(data)
if not shape:
const_utils.raise_index_error("When tensor is indexed by a slice, the dimension of the tensor cannot be 0.")
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_slice(shape, slice_index)
return F.strided_slice(data, begin_strides, end_strides, step_strides)
def tensor_index_by_integer(data, number):
"""Tensor getitem by a single integer number"""
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_integer(F.shape(data), number)
shape = F.shape(data)
if not shape:
const_utils.raise_index_error("When tensor is indexed by an integer, the dimension of the tensor cannot be 0.")
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_integer(shape, number)
shrink_axis_mask = 1
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)
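For reference, the integer path above maps a single index to the StridedSlice primitive: begin/end/step strides are built for the leading axis and shrink_axis_mask removes that axis from the output. A rough standalone illustration (values are arbitrary; P.StridedSlice takes begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask, as in the code above):

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.ops import operations as P

x = Tensor(np.arange(12).reshape(3, 4), ms.int32)

# x[1] lowers to a strided slice over axis 0: begin=(1,), end=(2,), strides=(1,),
# with shrink_axis_mask=1 so the sliced axis is dropped -> shape (4,), not (1, 4).
strided_slice = P.StridedSlice(0, 0, 0, 0, 1)
print(strided_slice(x, (1,), (2,), (1,)))  # [4 5 6 7]
```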
......@@ -214,7 +218,7 @@ def tensor_index_by_bool(data, bool_value):
"""Tensor getitem by a single bool value"""
if bool_value:
return F.expand_dims(data, 0)
return const_utils.raise_index_error("bool value as indexing ,false is not supported")
return const_utils.raise_index_error("When tensor is indexed by a bool object, the value only support 'True'.")
def tensor_index_by_number(data, number):
......@@ -224,7 +228,7 @@ def tensor_index_by_number(data, number):
return tensor_index_by_bool(data, number)
if number_type == const_utils.INT_:
return tensor_index_by_integer(data, number)
return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool")
return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool.")
def tensor_index_by_tensor(data, tensor_index):
......@@ -233,13 +237,18 @@ def tensor_index_by_tensor(data, tensor_index):
const_utils.TENSOR_GETITEM)
if dtype_valid:
return F.gather(data, tensor_index, 0)
return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool")
return const_utils.raise_index_error("For 'tensor getitem', "
"the index tensor data type only support mstype.int32.")
def tensor_index_by_tuple_slice(data, t):
"""Tensor getitem by a tuple of slice"""
shape = F.shape(data)
if len(t) > len(shape):
const_utils.raise_index_error("When tensor is indexed by a tuple, "
"the length of the tuple cannot be greater than the dimension of the tensor.")
begin_strides, end_strides, step_strides, shrink_axis_mask = \
const_utils.get_stride_info_from_tuple(F.shape(data), t)
const_utils.get_stride_info_from_tuple(shape, t)
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)
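The tuple-of-slices path gets the same style of guard: an index tuple longer than the tensor's rank is rejected before the strides are computed. A hedged sketch mirroring the first example (exact error wording and how the exception surfaces may differ across versions):

```python
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, context

context.set_context(mode=context.GRAPH_MODE)

class TupleSlice(nn.Cell):
    def construct(self, x):
        return x[0:1, 0:1, 0:1]          # three slices against a 2-D tensor

net = TupleSlice()
x = Tensor(np.ones((2, 3)), ms.float32)
try:
    net(x)                               # tuple length 3 > tensor rank 2 -> rejected
except (IndexError, RuntimeError) as e:  # surfaced exception type may vary by version
    print(e)
```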
......
......@@ -47,8 +47,8 @@ def _div_tensor(x, y):
Two tensors divide by element.
Args:
x (Tensor): x
y (Tensor): The dtype is same as x.
x (Tensor): The first input tensor.
y (Tensor): The second input tensor.
Returns:
Tensor, has the same dtype as x.
......
......@@ -34,7 +34,7 @@ def _floordiv_scalar(x, y):
@floordiv.register("Tensor", "Tensor")
def _floordiv_tensor(x, y):
"""Returns x // y where x and y are all tensors and have save dtype."""
"""Returns x // y where x and y are all tensors."""
return F.tensor_floordiv(x, y)
......
......@@ -164,7 +164,7 @@ def _tensor_getitem_by_number(data, number_index):
@getitem.register("Tensor", "None")
def _tensor_getitem_by_none(data, index):
"""
For none indexing , expand data with one dim
For none indexing , expand data with one dim.
Inputs:
data (Tensor): A tensor.
......
......@@ -25,7 +25,7 @@ greater_equal = base.MultitypeFuncGraph("greater_equal")
@greater_equal.register("Number", "Number")
def _greater_equal_scala(x, y):
"""
Determine whether x is greater equal than y
Determine whether x is greater equal than y.
Args:
x(Number): Number.
......
......@@ -48,6 +48,6 @@ def _greater_tensor(x, y):
y(Tensor): Tensor.
Returns:
tensor, return operation of x and y by P.Greater
tensor, return operation of x and y by P.Greater.
"""
return F.tensor_gt(x, y)
......@@ -25,7 +25,7 @@ less_equal = base.MultitypeFuncGraph("less_equal")
@less_equal.register("Number", "Number")
def _less_equal_scala(x, y):
"""
Determine whether x is less equal than y
Determine whether x is less equal than y.
Args:
x(Number): Number.
......@@ -41,7 +41,7 @@ def _less_equal_scala(x, y):
@less_equal.register("Tensor", "Tensor")
def _less_equal_tensor(x, y):
"""
Determine whether tensor x is less equal than tensor y elementwise
Determine whether tensor x is less equal than tensor y elementwise.
Args:
x(Tensor): Tensor.
......
......@@ -25,13 +25,13 @@ logical_not = base.MultitypeFuncGraph("logical_not")
@logical_not.register("Number")
def _logical_not_scala(x):
"""
Return logical not operation result of x
Return logical not operation result of x.
Args:
x(Number): Number.
Returns:
bool, Return logical not operation result of x
bool, Return logical not operation result of x.
"""
return F.bool_not(x.__bool__())
......@@ -39,10 +39,24 @@ def _logical_not_scala(x):
@logical_not.register("Tensor")
def _logical_not_tensor(x):
"""
Return logical not operation result of x
Return logical not operation result of x.
Args:
x(Tensor): Tensor.
Returns:
Tensor, Return logical not operation result of x
Tensor, Return logical not operation result of x.
"""
return F.logical_not(x)
return F.logical_not(x)
@logical_not.register("Tuple")
def _logical_not_tuple(x):
"""
Return logical not operation result of a tuple object.
Args:
x(Tuple): The input tuple.
Returns:
bool, Return logical not operation result of x.
"""
return F.bool_not(x.__bool__())
......@@ -25,14 +25,14 @@ logical_and = base.MultitypeFuncGraph("logical_and")
@logical_and.register("Number", "Number")
def _logical_and_scala(x, y):
"""
Return logical and operation result of x and y
Return logical and operation result of x and y.
Args:
x(Number): Number.
y(Number): Number.
Returns:
bool, Return logical and operation result of x and y
bool, Return logical and operation result of x and y.
"""
return F.bool_and(x.__bool__(), y.__bool__())
......@@ -40,13 +40,13 @@ def _logical_and_scala(x, y):
@logical_and.register("Tensor", "Tensor")
def _logical_and_tensor(x, y):
"""
Return logical and operation result of x and y
Return logical and operation result of x and y.
Args:
x(Tensor): Tensor.
y(Tensor): Tensor.
Returns:
Tensor, Return logical and operation result of x and y
Tensor, Return logical and operation result of x and y.
"""
return F.logical_and(x, y)
......@@ -25,14 +25,14 @@ logical_or = base.MultitypeFuncGraph("logical_or")
@logical_or.register("Number", "Number")
def _logical_or_scala(x, y):
"""
Return logical or operation result of x and y
Return logical or operation result of x and y.
Args:
x(Number): Number.
y(Number): Number.
Returns:
bool, Return logical or operation result of x and y
bool, Return logical or operation result of x and y.
"""
return F.bool_or(x.__bool__(), y.__bool__())
......@@ -40,13 +40,13 @@ def _logical_or_scala(x, y):
@logical_or.register("Tensor", "Tensor")
def _logical_or_tensor(x, y):
"""
Return logical operation or result of x and y
Return logical operation or result of x and y.
Args:
x(Tensor): Tensor.
y(Tensor): Tensor.
Returns:
Tensor, Return logical operation or result of x and y
Tensor, Return logical operation or result of x and y.
"""
return F.logical_or(x, y)
return F.logical_or(x, y)
......@@ -34,7 +34,7 @@ def _mod_scalar(x, y):
@mod.register("Tensor", "Tensor")
def _mod_tensor(x, y):
"""Returns x % y where x and y are all tensors and have save dtype."""
"""Returns x % y where x and y are all tensors."""
return F.tensor_mod(x, y)
......
......@@ -40,7 +40,7 @@ def _mul_scalar(x, y):
@mul.register("Tensor", "Tensor")
def _mul_tensor(x, y):
"""
Returns x * y by element-wise where x and y are all tensors and have same dtype.
Returns x * y by element-wise where x and y are all tensors.
Outputs:
Tensor, has the same dtype as x.
......
......@@ -34,7 +34,7 @@ def _sub_scalar(x, y):
@sub.register("Tensor", "Tensor")
def _sub_tensor(x, y):
"""Returns x - y where x and y are all tensors and have save dtype."""
"""Returns x - y where x and y are all tensors."""
return F.tensor_sub(x, y)
......
......@@ -1139,7 +1139,7 @@ raise_error_set = [
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
def test_exec():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
context.set_context(mode=context.GRAPH_MODE)
return test_cases
......