diff --git a/mindspore/_extends/parse/__init__.py b/mindspore/_extends/parse/__init__.py
index f8a34057c5262cf0f1d21209415e4d00a1387757..9366b5a2d23d05579c8b75e4bb06c5c53944cf7e 100644
--- a/mindspore/_extends/parse/__init__.py
+++ b/mindspore/_extends/parse/__init__.py
@@ -18,7 +18,7 @@ Interfaces for parser module in c++.
 
 from .parser import (Parser, create_obj_instance, generate_scope,
                      get_bprop_method_of_class, get_class_instance_type,
-                     get_class_member_namespace_symbol,
+                     get_class_member_namespace_symbol, create_slice_obj,
                      get_dataclass_attributes, get_dataclass_methods,
                      get_module_namespace, get_obj_type, get_object_key,
                      get_parse_method_of_class, get_scope_name,
@@ -29,4 +29,4 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class',
            'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_obj_type', 'create_obj_instance',
            'get_module_namespace', 'get_class_member_namespace_symbol', 'Parser', 'get_dataclass_attributes',
            'get_dataclass_methods', 'dump_obj', 'load_obj',
-           'get_dataclass_methods', 'get_scope_name']
+           'get_dataclass_methods', 'get_scope_name', 'create_slice_obj']
diff --git a/mindspore/_extends/parse/parser.py b/mindspore/_extends/parse/parser.py
index e88c9c15e92617c5a9a34467e8b764ea4fcbd3f9..d8039cd56a8f2a94f53f3337e85cc72acdeb29ce 100644
--- a/mindspore/_extends/parse/parser.py
+++ b/mindspore/_extends/parse/parser.py
@@ -29,6 +29,7 @@ from mindspore.common.dtype import pytype_to_dtype
 from mindspore.common.api import _MindSporeFunction
 from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace
 from .resources import parse_object_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
+from ..utils import Slice
 
 # define return value
 RET_SUCCESS = 0
@@ -69,6 +70,10 @@ parse_expr_statement_white_list = (
     "append",
 )
 
+def create_slice_obj(start, end, step):
+    """Create Slice object"""
+    return Slice(start, end, step)
+
 
 def parse_cb(func, parse_method=None):
     """Implements the function of parse."""
diff --git a/mindspore/_extends/utils.py b/mindspore/_extends/utils.py
index 8469ddda8b82710ebbb468b79271fdba2a6c273d..d0457607b5fa17280c50e1e78f730a474adb8ae4 100644
--- a/mindspore/_extends/utils.py
+++ b/mindspore/_extends/utils.py
@@ -19,6 +19,7 @@ import logging
 import os
 import inspect
 from functools import wraps
+from dataclasses import dataclass
 
 
 def cal_sha256(file_path):
@@ -99,3 +100,13 @@ def cell_attr_register(fn=None, attrs=None):
     if fn is not None:
         return wrap_cell(fn)
     return wrap_cell
+
+
+@dataclass
+class Slice:
+    """
+    Slice class
+    """
+    start: int
+    end: int
+    step: int
diff --git a/mindspore/ccsrc/ir/value.h b/mindspore/ccsrc/ir/value.h
index c80e22f735f39ee3d52ddbda2983ec3cf0db1a16..160eac7b5c51c5c605aeb4e1f541beab1aee09fd 100644
--- a/mindspore/ccsrc/ir/value.h
+++ b/mindspore/ccsrc/ir/value.h
@@ -123,6 +123,9 @@ class ValueSlice : public Value {
   abstract::AbstractBasePtr ToAbstract() override;
   std::string DumpText() const override { return ToString(); }
 
+  ValuePtr start() const { return start_; }
+  ValuePtr stop() const { return stop_; }
+  ValuePtr step() const { return step_; }
 
  private:
   ValuePtr start_;
diff --git a/mindspore/ccsrc/pipeline/parse/parse_base.h b/mindspore/ccsrc/pipeline/parse/parse_base.h
index aad8be0d6e96269b557fda4d2ad33be3e94b2b91..a3ca67b60a267b705826c6eca4e2a8773cc473c4 100644
--- a/mindspore/ccsrc/pipeline/parse/parse_base.h
+++ b/mindspore/ccsrc/pipeline/parse/parse_base.h
@@ -79,6 +79,8 @@ const char PYTHON_PARSE_EXPAND_EXPR_STATEMENT[] = "expand_expr_statement";
 const char PYTHON_PARSE_GENERATE_SCOPE[] = "generate_scope";
 const char PYTHON_PARSE_GET_SCOPE_NAME[] = "get_scope_name";
+const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj";
+
 
 // define the common name
 const char NAMED_PRIMITIVE_ITER[] = "iter";
 const char NAMED_PRIMITIVE_NEXT[] = "next";
diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/static_analysis/prim.cc
index 233d5df305594ab140a5f054f604d8fc3aedcc33..46e088ab11a6ef4c236046667a3bdb2de1e0f8e8 100644
--- a/mindspore/ccsrc/pipeline/static_analysis/prim.cc
+++ b/mindspore/ccsrc/pipeline/static_analysis/prim.cc
@@ -289,6 +289,13 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
     dic["shape"] = shape;
     dic["dtype"] = abs_base->BuildType();
     dic["value"] = BuildValue(abs_base->BuildValue());
+  } else if (abs_base->isa<AbstractSlice>()) {
+    auto arg_slice = dyn_cast<AbstractSlice>(abs_base);
+    std::vector<int> shape;
+    dic["shape"] = shape;
+    dic["dtype"] = arg_slice->BuildType();
+    dic["value"] = BuildValue(arg_slice->BuildValue());
+
   } else if (abs_base->isa<AbstractTuple>()) {
     auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
     size_t len = arg_tuple->size();
diff --git a/mindspore/ccsrc/utils/convert_utils.cc b/mindspore/ccsrc/utils/convert_utils.cc
index e840ff87344e6989ea890ff30289c3cfdfa6b886..049c1dcdb86ae274cf12c63bc55bffa029a2e02f 100644
--- a/mindspore/ccsrc/utils/convert_utils.cc
+++ b/mindspore/ccsrc/utils/convert_utils.cc
@@ -28,6 +28,7 @@
 
 #include "ir/meta_tensor.h"
 #include "pipeline/parse/parse.h"
+#include "pipeline/parse/parse_base.h"
 #include "ir/value.h"
 
 namespace mindspore {
@@ -97,6 +98,13 @@ py::object ValuePtrToPyData(const ValuePtr &value) {
       i++;
     }
     ret = rets;
+  } else if (value->isa<ValueSlice>()) {
+    auto slice = value->cast<ValueSlicePtr>();
+    auto start = ValuePtrToPyData(slice->start());
+    auto end = ValuePtrToPyData(slice->stop());
+    auto step = ValuePtrToPyData(slice->step());
+    ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_SLICE, start, end,
+                                          step);
   } else if (value->isa<tensor::Tensor>()) {
     py::tuple v(1);
     v[0] = value->cast<tensor::TensorPtr>();
diff --git a/mindspore/ops/composite/multitype_ops/_multitype_ops_util.py b/mindspore/ops/composite/multitype_ops/_multitype_ops_util.py
index b3687c553c9f36cf33716fc0482b12a7b58acc74..49773ff8ad81aa0431ad7c690753c918bbb3b99e 100644
--- a/mindspore/ops/composite/multitype_ops/_multitype_ops_util.py
+++ b/mindspore/ops/composite/multitype_ops/_multitype_ops_util.py
@@ -15,7 +15,43 @@
 """constexpr util"""
 
+import numpy as np
 from ...primitive import constexpr
+from ....common.tensor import Tensor
+from ....common import dtype as mstype
+from ...._extends.utils import Slice
+
+@constexpr
+def check_equal(param1, param2, msg="{},{}"):
+    if param1 != param2:
+        raise ValueError(msg.format(param1, param2))
+    return param1
+
+@constexpr
+def check_tensor_setitem_index(index, element_type=None):
+    """Check the index type for tensor assignment."""
+    if index is None:
+        raise ValueError("Tensor's index cannot be None.")
+    # eg. Tensor[Slice] = u
+    if isinstance(index, Slice):
+        return True
+    # eg. Tensor[Tuple] = u
+    if isinstance(index, tuple):
+        if not index:
+            raise ValueError("Tensor's index cannot be empty.")
+        # eg. Tensor[Tuple(Slice...)] = u
+        if not isinstance(index[0], Slice):
+            raise ValueError("Index of type '{}' is not supported yet.".format(type(index[0])))
+        return True
+    # eg. Tensor[Tensor[dtype=bool]] = u
+    if index == mstype.tensor:
+        if element_type is None or element_type != mstype.bool_:
+            raise ValueError(
+                "The index of tensor should be a bool type tensor. "
+                "{} type is not supported yet.".format(element_type))
+        return True
+
+    raise ValueError("Index of type '{}' is not supported yet.".format(type(index)))
 
 
 @constexpr
@@ -43,3 +79,84 @@ def error_msg(msg="", format_values=""):
     """
 
     raise ValueError(msg.format(*format_values))
+
+def slice_expand(input_slices, shape):
+    """
+    Expand a slice, or a tuple of slices, into (begin, end, strides) lists.
+
+    Inputs:
+        input_slices (Slice or tuple[Slice]): Slice object or tuple of Slice objects.
+        shape (tuple): The shape of the tensor, a tuple of integers.
+
+    Outputs:
+        (List, List, List), expressed as (begins, ends, strides).
+    """
+    begin = []
+    end = []
+    strides = []
+    index = 0
+    slices = None
+    # Slice or Tuple(Slice...)
+    if isinstance(input_slices, Slice):
+        slices = (input_slices,)
+    elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], Slice):
+        slices = input_slices
+    else:
+        raise ValueError("Tensor's index type is not supported yet.")
+
+    for s in slices:
+        start = 0 if (s.start is None) else s.start
+        stop = shape[index] if (s.end is None) else s.end
+        step = 1 if (s.step is None) else s.step
+        begin.append(start)
+        end.append(stop)
+        strides.append(step)
+        index += 1
+    while index < len(shape):
+        begin.append(0)
+        end.append(shape[index])
+        strides.append(1)
+        index += 1
+    return begin, end, strides
+
+@constexpr
+def slice2indices(input_slices, shape):
+    """
+    Convert a slice, or a tuple of slices, to a tensor of element indices.
+
+    Inputs:
+        input_slices (Slice or tuple[Slice]): Slice object or tuple of Slice objects.
+        shape (tuple): The shape of the tensor, a tuple of integers.
+
+    Outputs:
+        Tensor, the shape is (n, 1).
+    """
+    begin, end, strides = slice_expand(input_slices, shape)
+    np_r = []
+    for i, element in enumerate(shape):
+        s = begin[i] if (begin[i] >= 0) else (element + begin[i])
+        e = end[i] if (end[i] >= 0) else (element + end[i])
+        np_r.append(np.r_[s:e:strides[i]])
+    # Reference: np.ravel_multi_index((np.ix_(np.r_[1:3:1], np.r_[0:4:1], np.r_[4:0:-1])), a.shape)
+    np_ix = np.ix_(*np_r)
+    ravel = np.ravel_multi_index(np_ix, shape)
+    ravel = Tensor(ravel.reshape(-1, 1), dtype=mstype.int32)
+    return ravel
+
+@constexpr
+def check_indices(indices_size, index):
+    if indices_size < 1:
+        raise ValueError("The tensor index selects no elements. index: {}".format(index))
+    return indices_size
+
+
+@constexpr
+def check_indices_value_size(indices_size, value_size):
+    if value_size < 1:
+        raise ValueError("The value assigned to tensor cannot be empty.")
+    if value_size > 1:
+        if value_size != indices_size:
+            raise ValueError(
+                "The value given to tensor does not match the index size. "
+                "value size: {}, indices size: {}".format(value_size, indices_size))
+    return value_size
diff --git a/mindspore/ops/composite/multitype_ops/setitem_impl.py b/mindspore/ops/composite/multitype_ops/setitem_impl.py
index 31c96932c5edaf418da300febbb2dae422c28721..742ee57166fadd0e2d450465a5d7a75846e71b08 100644
--- a/mindspore/ops/composite/multitype_ops/setitem_impl.py
+++ b/mindspore/ops/composite/multitype_ops/setitem_impl.py
@@ -138,25 +138,23 @@ def _tensor_setitem_by_tensor_v1(data, index, value_tensor):
     Outputs:
         Tensor, element type and shape is same as data.
     """
+    result = None
     index_dtype = F.dtype(index)
     index_shape = F.shape(index)
-    is_bool = mult_util.is_same_type(index_dtype, mstype.bool_)
-    if not is_bool:
-        return mult_util.error_msg(
-            "The tensor index should be a bool type tensor. {} type tensor is not supported yet.", (index_dtype,))
{} type tensor is not supported yet.", (index_dtype,)) - data_shape = F.shape(data) - if index_shape != data_shape: - return mult_util.error_msg( - "The tensor(shape={}) and tensor index(shape={}) should be the same shape.", (data_shape, index_shape)) - size = F.size(value_tensor) - if size != 1: - return mult_util.error_msg( - "When assign value is a tensor, its size should be 1, but current size is {}.", (size,)) - dtype = F.dtype(data) - u_cast = F.cast(value_tensor, dtype) - one_data = F.ones_like(data) - u = F.tensor_mul(one_data, u_cast) - return F.select(index, u, data) + check_result = mult_util.check_tensor_setitem_index(mstype.tensor, index_dtype) + if check_result: + data_shape = F.shape(data) + data_shape = mult_util.check_equal(data_shape, index_shape, + "The tensor(shape={}) and tensor index(shape={}) should be the same shape.") + size = F.size(value_tensor) + size = mult_util.check_equal(1, size, + "When assign value is a tensor, its size should be {}, but current size is {}.") + dtype = F.dtype(data) + u_cast = F.cast(value_tensor, dtype) + one_data = F.ones_like(data) + u = F.tensor_mul(one_data, u_cast) + result = F.select(index, u, data) + return result @setitem.register("Tensor", "Tensor", "Number") @@ -179,16 +177,162 @@ def _tensor_setitem_by_tensor_v2(data, index, value): Outputs: Tensor, element type and shape is same as data. """ + result = None index_dtype = F.dtype(index) index_shape = F.shape(index) - is_bool = mult_util.is_same_type(index_dtype, mstype.bool_) - if not is_bool: - return mult_util.error_msg( - "The tensor index should be a bool type tensor. {} type tensor is not supported yet.", (index_dtype,)) - shape = F.shape(data) - if index_shape != shape: - return mult_util.error_msg( - "The tensor(shape={}) and tensor index(shape={}) should be the same shape.", (shape, index_shape)) - dtype = F.dtype(data) - u = F.fill(dtype, shape, value) - return F.select(index, u, data) + check_result = mult_util.check_tensor_setitem_index(mstype.tensor, index_dtype) + if check_result: + shape = F.shape(data) + shape = mult_util.check_equal( + shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.") + dtype = F.dtype(data) + u = F.fill(dtype, shape, value) + result = F.select(index, u, data) + return result + + +@setitem.register("Tensor", "Slice", "Tensor") +def _tensor_setitem_with_slice_v3(data, input_slice, value): + """ + Tensor assignment. + + Note: + Syntax support: A[Slice] = U + Restraint condition: A is a Tensor + Slice like "1:3" + U is a Tensor(size=1) or Tensor(size>1) + + Inputs: + data (Tensor): Assigned tensor. + input_slice (Slice): Slice expression. + value (Number): Assignment value. + + Outputs: + Tensor, element type and shape is same as data. + """ + return _tensor_assgin_tensor(data, input_slice, value) + + +@setitem.register("Tensor", "Tuple", "Tensor") +def _tensor_setitem_with_slice_v4(data, input_slice, value): + """ + Tensor assignment. + + Note: + Syntax support: A[Slice] = U + Restraint condition: A is a Tensor + Slice like "1:3, ::, :4:-1" + U is a Tensor(size=1) or Tensor(size>1) + + Inputs: + data (Tensor): Assigned tensor. + input_slice (Tuple(Slice)): Slice expression. + value (Number): Assignment value. + + Outputs: + Tensor, element type and shape is same as data. + """ + return _tensor_assgin_tensor(data, input_slice, value) + + +def _tensor_assgin_tensor(data, input_slice, value): + """Given a tensor value assign to tensor by slice""" + # 1. 
+    result = None
+    check_result = mult_util.check_tensor_setitem_index(input_slice)
+    if check_result:
+        data_shape = F.shape(data)
+        data_size = F.size(data)
+        data_dtype = F.dtype(data)
+        indices = mult_util.slice2indices(input_slice, data_shape)
+        indices_size = F.size(indices)
+        indices_size = mult_util.check_indices(indices_size, input_slice)
+        update = F.fill(data_dtype, (indices_size,), 1)
+        condition_1d = F.scatter_nd(indices, update, (data_size,))
+        condition_1d = F.cast(condition_1d, mstype.bool_)
+        condition = F.reshape(condition_1d, data_shape)
+        # 2. Build the update tensor u.
+        value_fill = None
+        value_size = F.size(value)
+
+        value_size = mult_util.check_indices_value_size(indices_size, value_size)
+        if value_size == 1:
+            value_fill = F.fill(data_dtype, (indices_size,), 1)
+            value = F.cast(value, data_dtype)
+            value_fill = F.tensor_mul(value_fill, value)
+        elif value_size > 1:
+            value_fill = F.reshape(value, (indices_size,))
+        value_1d = F.scatter_nd(indices, value_fill, (data_size,))
+        u = F.reshape(value_1d, data_shape)
+        # A[slice] = u  ->  A[B] = U  ->  select(B, U, A)
+        result = F.select(condition, u, data)
+    return result
+
+
+@setitem.register("Tensor", "Slice", "Number")
+def _tensor_setitem_with_slice_v1(data, input_slice, value):
+    """
+    Tensor assignment.
+
+    Note:
+        Syntax support: A[Slice] = u
+        Restrictions: A is a Tensor.
+                      Slice like "1:3"
+                      u is a scalar
+
+    Inputs:
+        data (Tensor): Assigned tensor.
+        input_slice (Slice): slice expression.
+        value (Number): Assignment value.
+
+    Outputs:
+        Tensor, element type and shape is same as data.
+    """
+    return _tensor_assgin_number(data, input_slice, value)
+
+
+@setitem.register("Tensor", "Tuple", "Number")
+def _tensor_setitem_with_slice_v2(data, input_slice, value):
+    """
+    Tensor assignment.
+
+    Note:
+        Syntax support: A[Tuple(Slice...)] = u
+        Restrictions: A is a Tensor.
+                      Slice like "1:3, ::, :4:-1"
+                      u is a scalar
+
+    Inputs:
+        data (Tensor): Assigned tensor.
+        input_slice (Tuple(Slice)): slice expression.
+        value (Number): Assignment value.
+
+    Outputs:
+        Tensor, element type and shape is same as data.
+    """
+    return _tensor_assgin_number(data, input_slice, value)
+
+
+def _tensor_assgin_number(data, input_slice, value):
+    """Assign the given scalar value to a tensor by slice."""
+    # 1. Build a bool condition tensor marking the sliced positions.
+    check_result = mult_util.check_tensor_setitem_index(input_slice)
+    result = None
+    if check_result:
+        data_shape = F.shape(data)
+        data_size = F.size(data)
+        data_dtype = F.dtype(data)
+        indices = mult_util.slice2indices(input_slice, data_shape)
+        indices_size = F.size(indices)
+        indices_size = mult_util.check_indices(indices_size, input_slice)
+        update = F.fill(data_dtype, (indices_size,), 1)
+        condition_1d = F.scatter_nd(indices, update, (data_size,))
+        condition_1d = F.cast(condition_1d, mstype.bool_)
+        condition = F.reshape(condition_1d, data_shape)
+        # 2. Build the update tensor u filled with the scalar value.
+        value_fill = F.fill(data_dtype, (indices_size,), value)
+        value_1d = F.scatter_nd(indices, value_fill, (data_size,))
+        u = F.reshape(value_1d, data_shape)
+        # A[slice] = u  ->  A[B] = U  ->  select(B, U, A)
+        result = F.select(condition, u, data)
+    return result
diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py
index c5b8752ae2a1d7ecd9a793217db5b668fde743ba..4135133e855af58d9cd0520cca34b8ae45c40994 100644
--- a/mindspore/ops/functional.py
+++ b/mindspore/ops/functional.py
@@ -68,6 +68,7 @@ tuple_to_array = P.TupleToArray()
 scalar_cast = P.ScalarCast()
 print_ = P.Print()
 expand_dims = P.ExpandDims()
+scatter_nd = P.ScatterNd()
 
 tuple_setitem = Primitive('tuple_setitem')
 tuple_getitem = Primitive('tuple_getitem')
diff --git a/tests/ut/python/ops/test_tensor_slice.py b/tests/ut/python/ops/test_tensor_slice.py
index ddd1fb46a117cbc521b16c361bdac03aef451f1c..08ba143de82706af23ff4f7bfb579a60be5940f1 100644
--- a/tests/ut/python/ops/test_tensor_slice.py
+++ b/tests/ut/python/ops/test_tensor_slice.py
@@ -94,10 +94,101 @@ class NetWorkReduceToScalar(Cell):
         return ret
 
 
+class TensorAssignWithSliceError1(Cell):
+    def __init__(self):
+        super(TensorAssignWithSliceError1, self).__init__()
+
+    def construct(self, a, b):
+        a[1:3:-1, ::] = b
+        return a
+
+class TensorAssignWithSliceError2(Cell):
+    def __init__(self):
+        super(TensorAssignWithSliceError2, self).__init__()
+
+    def construct(self, a, b):
+        a[1:3:-1] = b
+        return a
+class TensorAssignWithSlice2(Cell):
+    def __init__(self):
+        super(TensorAssignWithSlice2, self).__init__()
+
+    def construct(self, a, b):
+        a[1:5] = b
+        a[3:4] = 5
+        a[-1:1:-1] = b
+        a[-1:3:-1] = 5
+        a[::] = b
+        a[::] = 9
+        return a
+class TensorAssignWithSlice(Cell):
+    def __init__(self):
+        super(TensorAssignWithSlice, self).__init__()
+        self.c = 2
+
+    def construct(self, a, b):
+        a[1:3, ::] = b
+        a[2:3:, 3:] = b
+        a[::] = b
+        a[::] = self.c
+        a[::, ::] = b
+        a[::, ::] = self.c
+        a[2:3:, 0:, 4:1:-1] = b
+        a[2:3:, 0:, 4:1:-1] = self.c
+        z = a
+        return z
+
+def test_tensor_assign_with_slice():
+    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
+    net = TensorAssignWithSlice()
+    net2 = TensorAssignWithSlice2()
+    net_e1 = TensorAssignWithSliceError1()
+    net_e2 = TensorAssignWithSliceError2()
+    a = np.arange(60).reshape(3, 4, 5)
+    b = Tensor([1])
+    Ta = Tensor(a)
+    Tb = Tensor([1, 3])
+    Tc = Tensor([])
+    t = Tensor([1, 2, 3, 4, 5, 6, 7, 8])
+    net(Ta, b)
+    net2(t, b)
+    # Error for A[Slice] = Number
+    # 1. A[Slice] = Number, the slice is invalid (selects no elements)
+    with pytest.raises(ValueError):
+        net_e2(t, 2)
+
+    # Error for A[Slice] = U, U is a Tensor
+    # 1. A[Slice] = U, U.size does not match the slice
+    with pytest.raises(ValueError):
+        net2(t, Tb)
+    # 2. A[Slice] = U, U is empty
+    with pytest.raises(ValueError):
+        net2(t, Tc)
+    # 3. A[Slice] = U, U.size does not match the slice
+    with pytest.raises(ValueError):
+        net2(t, Tb)
+
+    # Error for A[Tuple(Slice...)] = Tensor
+    # 1. A[Tuple(Slice...)] = U, U is empty
+    with pytest.raises(ValueError):
+        net(Ta, Tc)
+    # 2. A[Tuple(Slice...)] = U, U.size does not match the slices
+    with pytest.raises(ValueError):
+        net(Ta, Tb)
+    # 3. A[Tuple(Slice...)] = U, the slice is invalid (selects no elements)
+    with pytest.raises(ValueError):
+        net_e1(Ta, b)
+
+    # Error for A[Tuple(Slice...)] = Number
+    # 1. A[Tuple(Slice...)] = Number, the slice is invalid (selects no elements)
+    with pytest.raises(ValueError):
+        net_e1(Ta, 2)
+
+
 class TensorAssignWithBoolTensorIndex(Cell):
     def __init__(self):
         super(TensorAssignWithBoolTensorIndex, self).__init__()
-        self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float64)
+        self.t = Tensor(np.arange(60).reshape([3, 4, 5]), dtype=mstype.float64)
 
     def construct(self, a, b, c, u_tensor, _scalar):
         a[c] = u_scalar
@@ -119,6 +210,7 @@ class TensorAssignWithBoolTensorIndex2(Cell):
     def __init__(self):
         super(TensorAssignWithBoolTensorIndex2, self).__init__()
         self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float64)
+        self.t = Tensor(np.arange(60).reshape([3, 4, 5]), dtype=mstype.float64)
 
     def construct(self, a, u_tensor, _scalar):
         a[a > 8] = u_tensor
@@ -139,7 +231,7 @@ class TensorAssignWithBoolTensorIndex2Error(Cell):
         return a
 
 
-a = np.random.uniform(1, 10, [2, 3])
+a = np.random.uniform(1, 10, [3, 4, 5])
 b = a > 5
 c = a < 3
 Ta = Tensor(a)
@@ -148,13 +240,13 @@ Tc = Tensor(c)
 Td = Tensor([True, True])
 u_tensor = Tensor([1])
 u_tensor_error = Tensor([1, 2])
+t_1d = Tensor([1, 2, 3, 4, 5, 6, 7, 8])
 u_scalar = 5
 
-
 def test_tensor_assign_bool_index():
     net1 = TensorAssignWithBoolTensorIndex()
     net2 = TensorAssignWithBoolTensorIndex2()
-
+    net1(Ta, Tb, Tc, u_tensor, u_scalar)
     net1(Ta, Tb, Tc, u_tensor, u_scalar)
     with pytest.raises(ValueError):
         net1(Ta, Td, Tc, u_tensor, u_scalar)
@@ -180,8 +272,15 @@ def test_tensor_assign_bool_index():
     with pytest.raises(AttributeError):
         net4(Ta, u_scalar)
 
-
 test_cases = [
+    ('TensorAssignWithSlice', {
+        'block': TensorAssignWithSlice(),
+        'desc_inputs': [Ta, u_tensor],
+    }),
+    ('TensorAssignWithSlice2', {
+        'block': TensorAssignWithSlice2(),
+        'desc_inputs': [t_1d, u_tensor],
+    }),
    ('TensorAssignWithBoolTensorIndex', {
        'block': TensorAssignWithBoolTensorIndex(),
        'desc_inputs': [Ta, Tb, Tc, u_tensor, u_scalar],