提交 33e427e1 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!721 Tensor assign by ellipsis

Merge pull request !721 from candanzg/tensor_assgin_ellipsis
......@@ -22,11 +22,11 @@ from .parser import (Parser, create_obj_instance, generate_scope,
get_dataclass_attributes, get_dataclass_methods,
get_module_namespace, get_obj_type, get_object_key,
get_parse_method_of_class, get_scope_name,
is_class_member, parse_cb, resolve_symbol)
is_class_member, parse_cb, resolve_symbol, create_ellipsis_obj)
from .serialize import *
__all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_obj_type',
'create_obj_instance', 'get_module_namespace', 'get_class_member_namespace_symbol',
'Parser', 'get_dataclass_attributes', 'get_dataclass_methods', 'dump_obj', 'load_obj',
'get_dataclass_methods', 'get_scope_name', 'create_slice_obj']
'get_dataclass_methods', 'get_scope_name', 'create_slice_obj', 'create_ellipsis_obj']
......@@ -29,7 +29,7 @@ from mindspore.common.dtype import pytype_to_dtype
from mindspore.common.api import _MindSporeFunction
from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace
from .resources import parse_object_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
from ..utils import Slice
from ..utils import Slice, Ellipsis_
# define return value
RET_SUCCESS = 0
......@@ -70,6 +70,11 @@ parse_expr_statement_white_list = (
"append",
)
def create_ellipsis_obj():
"""Create Slice object"""
return Ellipsis_()
def create_slice_obj(start, end, step):
"""Create Slice object"""
return Slice(start, end, step)
......
......@@ -110,3 +110,10 @@ class Slice:
start: int
end: int
step: int
@dataclass
class Ellipsis_:
"""
Ellipsis class
"""
......@@ -80,6 +80,7 @@ const char PYTHON_PARSE_GENERATE_SCOPE[] = "generate_scope";
const char PYTHON_PARSE_GET_SCOPE_NAME[] = "get_scope_name";
const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj";
const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj";
// define the common name
const char NAMED_PRIMITIVE_ITER[] = "iter";
......
......@@ -298,6 +298,12 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
} else if (abs_base->isa<AbstractRef>()) {
auto value = abs_base->cast<AbstractRefPtr>()->ref();
dic = ConvertAbstractToPython(value);
} else if (abs_base->isa<AbstractEllipsis>()) {
auto arg_slice = dyn_cast<AbstractEllipsis>(abs_base);
std::vector<int> shape;
dic["shape"] = shape;
dic["dtype"] = arg_slice->BuildType();
dic["value"] = BuildValue(arg_slice->BuildValue());
} else if (abs_base->isa<AbstractTuple>()) {
auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
size_t len = arg_tuple->size();
......
......@@ -98,6 +98,8 @@ py::object ValuePtrToPyData(const ValuePtr &value) {
i++;
}
ret = rets;
} else if (value->isa<EllipsisObj>()) {
ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_ELLIPSIS);
} else if (value->isa<ValueSlice>()) {
auto slice = value->cast<ValueSlicePtr>();
auto start = ValuePtrToPyData(slice->start());
......
......@@ -20,7 +20,7 @@ import numpy as np
from ...primitive import constexpr
from ....common.tensor import Tensor
from ....common import dtype as mstype
from ...._extends.utils import Slice
from ...._extends.utils import Slice, Ellipsis_
@constexpr
def check_equal(param1, param2, msg="{},{}"):
......@@ -29,31 +29,40 @@ def check_equal(param1, param2, msg="{},{}"):
raise ValueError(msg.format(param1, param2))
return param1
@constexpr
def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size):
"""Checks the shape and size of the sensor and value."""
if data_shape == value_shape or data_size == value_size or value_size == 1:
return True
raise ValueError("The value(shape={}), can not assign to tensor(shape={}).".format(value_shape, data_shape))
@constexpr
def check_tensor_setitem_index(index, element_type=None):
"""Checks tuple index type of tensor assignment."""
if index is None:
raise ValueError("Tensor's index cannot be None.")
raise IndexError("Tensor's index cannot be None.")
# eg. Tensor[Slice] = u
if isinstance(index, Slice):
return True
# eg. Tensor[tuple] = u
if isinstance(index, tuple):
if not index:
raise ValueError("Tensor's index cannot be empty.")
raise IndexError("Tensor's index cannot be empty.")
# eg. Tensor[tuple(Slice...)] = u
if isinstance(index[0], (Slice, int)):
if isinstance(index[0], (Slice, Ellipsis_, int)):
return True
raise ValueError("Index of type '{}' is not supported yet.".format(type(index[0])))
raise IndexError("Index of type '{}' is not supported yet.".format(type(index[0])))
# eg. Tensor[Tensor[dtype=bool]] = u
if index == mstype.tensor:
if element_type is None or element_type != mstype.bool_:
raise ValueError(
"The index of tensor should be a bool type tensor. \
{} type is not supported yet.".format(element_type))
raise TypeError(
"The index of tensor should be a bool type tensor. "
"{} type is not supported yet.".format(element_type))
return True
raise ValueError("Index of type '{}' is not supported yet.".format(type(index)))
raise IndexError("Index of type '{}' is not supported yet.".format(type(index)))
@constexpr
......@@ -90,10 +99,18 @@ def slice_expand(input_slices, shape):
# Slice or tuple(Slice...)
if isinstance(input_slices, Slice):
slices = (input_slices,)
elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], Slice):
slices = input_slices
elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], (Slice, Ellipsis_)):
is_have_ellipsis = False
for _, element in enumerate(input_slices):
if isinstance(element, Ellipsis_):
is_have_ellipsis = True
break
if is_have_ellipsis:
slices = ellipsis2slice(input_slices, shape)
else:
slices = input_slices
else:
raise ValueError("Tensor's index type is not supported yet.")
raise IndexError("Tensor's index type is not supported yet.")
for s in slices:
start = 0 if (s.start is None) else s.start
......@@ -111,6 +128,26 @@ def slice_expand(input_slices, shape):
return begin, end, strides
def ellipsis2slice(input_, shape):
"""Converts ellipsis to slice."""
input_slice = input_
result = []
if isinstance(input_, Ellipsis_):
input_slice = (input_,)
ell_count = 0
for _, element in enumerate(input_slice):
if not isinstance(element, Ellipsis_):
result.append(element)
continue
ell_count += 1
if ell_count > 1:
raise IndexError("There cannot be more than one ellisis (...) in the index of the tensor, "
"but it is currently {}".format(input_slice))
for _ in range(len(shape) - len(input_slice) + 1):
result.append(Slice(None, None, None))
return tuple(result)
@constexpr
def slice2indices(input_slices, shape):
"""
......@@ -139,7 +176,7 @@ def slice2indices(input_slices, shape):
def check_indices(indices_size, index):
"""Checks indices whether is empty."""
if indices_size < 1:
raise ValueError("The tensor's index is unreasonable. index:{}".format(index))
raise IndexError("The tensor's index is unreasonable. index:{}".format(index))
return indices_size
......@@ -151,8 +188,8 @@ def check_indices_value_size(indices_size, value_size):
if value_size > 1:
if value_size != indices_size:
raise ValueError(
"The value given to tensor does not match the index size. \
value size:{}, indics size:{}".format(value_size, indices_size))
"The value given to tensor does not match the index size,"
" value size:{}, indics size:{}".format(value_size, indices_size))
return value_size
@constexpr
......@@ -168,8 +205,11 @@ def integer_to_indices(index, shape):
def tuple_element_is_slice(indexs):
"""Judges tuple element type."""
if not indexs:
raise ValueError("Tensor's index cannot be empty.")
if isinstance(indexs, tuple) and isinstance(indexs[0], Slice):
raise IndexError("Tensor's index cannot be empty.")
if isinstance(indexs, tuple):
for _, ele in enumerate(indexs):
if not isinstance(ele, Slice):
return False
return True
return False
......@@ -177,7 +217,10 @@ def tuple_element_is_slice(indexs):
def tuple_element_is_int(indexs):
"""Judges tuple element type."""
if not indexs:
raise ValueError("Tensor's index cannot be empty.")
if isinstance(indexs, tuple) and isinstance(indexs[0], int):
raise IndexError("Tensor's index cannot be empty.")
if isinstance(indexs, tuple):
for _, ele in enumerate(indexs):
if not isinstance(ele, int):
return False
return True
return False
......@@ -254,10 +254,10 @@ def _tensor_indices_tensor(data, data_shape, index, indices, value):
data_dtype = F.dtype(data)
indices_size = F.size(indices)
indices_size = mult_util.check_indices(indices_size, index)
update = F.fill(data_dtype, (indices_size,), 1)
update = F.fill(mstype.int32, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition_1d = F.cast(condition_1d, mstype.bool_)
condition = F.reshape(condition_1d, data_shape)
condition = F.cast(condition, mstype.bool_)
value_fill = None
value_size = F.size(value)
......@@ -336,10 +336,10 @@ def _tensor_indices_number(data, data_shape, index, indices, value):
data_dtype = F.dtype(data)
indices_size = F.size(indices)
indices_size = mult_util.check_indices(indices_size, index)
update = F.fill(data_dtype, (indices_size,), 1)
update = F.fill(mstype.int32, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition_1d = F.cast(condition_1d, mstype.bool_)
condition = F.reshape(condition_1d, data_shape)
condition = F.cast(condition, mstype.bool_)
value_fill = F.fill(data_dtype, (indices_size,), value)
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
u = F.reshape(value_1d, data_shape)
......@@ -360,3 +360,32 @@ def _tensor_setitem_with_int_v2(data, index, value):
data_shape = F.shape(data)
indices = mult_util.integer_to_indices(index, data_shape)
return _tensor_indices_tensor(data, data_shape, index, indices, value)
@setitem.register("Tensor", "Ellipsis", "Number")
def _tensor_setitem_with_ellipsis_v1(data, index, value):
"""Syntax: A[...] = number."""
data_shape = F.shape(data)
data_dtype = F.dtype(data)
return F.fill(data_dtype, data_shape, value)
@setitem.register("Tensor", "Ellipsis", "Tensor")
def _tensor_setitem_with_ellipsis_v2(data, index, value):
"""Syntax: A[...] = Tensor."""
result = None
data_shape = F.shape(data)
data_dtype = F.dtype(data)
data_size = F.size(data)
value_shape = F.shape(value)
value_size = F.size(value)
check_result = mult_util.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size)
if check_result:
if data_size == value_size:
result = F.reshape(value, data_shape)
result = F.cast(result, data_dtype)
elif value_size == 1:
param1 = F.fill(data_dtype, data_shape, 1)
param2 = F.cast(value, data_dtype)
result = F.tensor_mul(param1, param2)
return result
......@@ -103,6 +103,7 @@ class TensorAssignWithSliceError1(Cell):
a[1:3:-1,::] = b
return a
class TensorAssignWithSliceError2(Cell):
def __init__(self):
super(TensorAssignWithSliceError2, self).__init__()
......@@ -110,24 +111,29 @@ class TensorAssignWithSliceError2(Cell):
def construct(self, a, b):
a[1:3:-1] = b
return a
class TensorAssignWithSlice2(Cell):
def __init__(self):
super(TensorAssignWithSlice2, self).__init__()
def construct(self, a, b):
def construct(self, a, b, ck):
a[1:5] = b
a[3:4] = 5
a[-1:1:-1] = b
a[-1:3:-1] = 5
a[::] = b
a[::] = 9
return a
z = a + ck
return z
class TensorAssignWithSlice(Cell):
def __init__(self):
super(TensorAssignWithSlice, self).__init__()
self.c = 2
def construct(self, a, b):
def construct(self, a, b, ck):
a[1:3,::] = b
a[2:3:,3:] = b
a[::] = b
......@@ -136,9 +142,10 @@ class TensorAssignWithSlice(Cell):
a[::,::] = self.c
a[2:3:,0:, 4:1:-1] = b
a[2:3:,0:, 4:1:-1] = self.c
z = a
z = a + ck
return z
def test_tensor_assign():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
net = TensorAssignWithSlice()
......@@ -146,95 +153,145 @@ def test_tensor_assign():
net_e1 = TensorAssignWithSliceError1()
net_e2 = TensorAssignWithSliceError2()
a = np.arange(60).reshape(3,4,5)
b = Tensor([1])
Ta = Tensor(a)
Ta4d = Tensor(a.reshape(1,3,4,5))
Tb= Tensor([1,3])
Tc= Tensor([])
t = Tensor([1, 2, 3, 4, 5, 6, 7, 8])
net(Ta, b)
net2(t, b)
ck = np.arange(60).reshape(3,4,5)
b = Tensor([1], dtype=mstype.float32)
Ta = Tensor(a, dtype=mstype.float32)
Tck = Tensor(ck, dtype=mstype.float32)
Ta4d = Tensor(a.reshape(1,3,4,5), dtype=mstype.float32)
Ta4d_ck = Tensor(ck.reshape(1,3,4,5), dtype=mstype.float32)
Tb= Tensor([1,3], dtype=mstype.float32)
Tc= Tensor([], dtype=mstype.float32)
t = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
tck = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
net(Ta, b, Tck)
net2(t, b, tck)
# Error for A[Slice] = Number
# 1. A[Slice] = Number, Slice error
with pytest.raises(ValueError):
with pytest.raises(IndexError):
net_e2(t, 2)
# Error for A[Slice] = U, U is a Tensor
# 1. A[Slice] = U, u.size is error
with pytest.raises(ValueError):
net2(t, Tb)
net2(t, Tb, tck)
# 2. A[Slice] = U, U is empty
with pytest.raises(ValueError):
net2(t, Tc)
net2(t, Tc, tck)
# 3. A[Slice] = U, U.size error
with pytest.raises(ValueError):
net2(t, Tb)
net2(t, Tb, tck)
# Error for A[Tuple(Slice...)] = Tensor
# 1. A[Tuple(Slice...)] = U, U is empty
with pytest.raises(ValueError):
net(Ta, Tc)
net(Ta, Tc, Tck)
# 2. A[Tuple(Slice...)] = U, U.size error
with pytest.raises(ValueError):
net(Ta, Tb)
net(Ta, Tb, Tck)
# 3. A[Tuple(Slice...)] = U, Slice error
with pytest.raises(ValueError):
with pytest.raises(IndexError):
net_e1(Ta, b)
# Error for A[Tuple(Slice...)] = Number
# 1. A[Tuple(Slice...)] = Number, Slice error
with pytest.raises(ValueError):
with pytest.raises(IndexError):
net_e1(Ta, 2)
net = TensorAssignWithInteger()
# Error for A[Number] = scalar/Tensor
# 1. A[Number] = U, U is a Tensor, u.size not match
with pytest.raises(ValueError):
net(Ta, Tb)
net(Ta, Tb, Tck)
with pytest.raises(ValueError):
net(Ta, Tc)
net(Ta, Tc, Tck)
# 2. A[Number] = U, the number index error
with pytest.raises(IndexError):
net(Ta4d, b)
net(Ta4d, b, Ta4d_ck)
# Error for A[(n,m)] = scalar/Tensor
# 1. A[(n,m)] = U, U is a tensor. u.size not match
net = TensorAssignWithTupleInteger()
with pytest.raises(ValueError):
net(Ta, Tc)
net(Ta, Tc, Tck)
with pytest.raises(ValueError):
net(Ta, Tb)
net(Ta, Tb, Tck)
# 2. A[(n,m)] = U, the number index error
with pytest.raises(IndexError):
net(Ta4d, b)
net(Ta4d, b, Ta4d_ck)
#Error for A[...] = U or A[1:, ...] = u
#1. A[...] = scalar/tensor
net = TensorAssignWithEllipsis()
net(Ta, Ta4d)
with pytest.raises(ValueError):
net(Ta, Tc)
with pytest.raises(ValueError):
net(Ta, Tb)
#2. A[::, 1:, ...] = scalar/tensor
net = TensorAssignWithTupleEllipsis()
net(Ta, b)
with pytest.raises(ValueError):
net(Ta, Tc)
with pytest.raises(ValueError):
net(Ta, Tb)
class TensorAssignWithTupleEllipsis2(Cell):
def __init__(self):
super(TensorAssignWithTupleEllipsis2, self).__init__()
def construct(self, a, b):
a[1:, ..., ::] = b
return a
class TensorAssignWithTupleEllipsis(Cell):
def __init__(self):
super(TensorAssignWithTupleEllipsis, self).__init__()
def construct(self, a, b):
a[:2, ...] = 1
a[1:, ...] = b
return a
class TensorAssignWithEllipsis(Cell):
def __init__(self):
super(TensorAssignWithEllipsis, self).__init__()
def construct(self, a, b):
a[...] = 1
a[...] = b
return a
class TensorAssignWithInteger(Cell):
def __init__(self):
super(TensorAssignWithInteger, self).__init__()
def construct(self, a, b):
def construct(self, a, b, ck):
a[1] = 1
a[0] = b
return a
z = a + ck
return z
class TensorAssignWithTupleInteger(Cell):
def __init__(self):
super(TensorAssignWithTupleInteger, self).__init__()
def construct(self, a, b):
def construct(self, a, b, ck):
a[(1)] = 1
a[(1)] = b
a[(1,1)] = b
a[(1,1)] = 1
return a
z = a + ck
return z
class TensorAssignWithBoolTensorIndex(Cell):
def __init__(self):
super(TensorAssignWithBoolTensorIndex, self).__init__()
self.t = Tensor(np.arange(60).reshape([3,4,5]), dtype = mstype.float64)
self.t = Tensor(np.arange(60).reshape([3,4,5]), dtype = mstype.float32)
self.u_scalar = 5
def construct(self, a, b, c, u_tensor, _scalar):
a[c] = u_scalar
def construct(self, a, b, c, u_tensor):
a[c] = self.u_scalar
a[b] = u_tensor
z = a + self.t
return z
......@@ -252,15 +309,16 @@ class TensorAssignWithBoolTensorIndexError(Cell):
class TensorAssignWithBoolTensorIndex2(Cell):
def __init__(self):
super(TensorAssignWithBoolTensorIndex2, self).__init__()
self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float64)
self.t = Tensor(np.arange(60).reshape([3,4,5]), dtype = mstype.float64)
self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float32)
self.t = Tensor(np.arange(60).reshape([3,4,5]), dtype = mstype.float32)
self.u_scalar = 5
def construct(self, a, u_tensor, _scalar):
def construct(self, a, u_tensor):
a[a > 8] = u_tensor
a[a >= 6] = u_scalar
a[a < 3] = u_scalar
a[a >= 6] = self.u_scalar
a[a < 3] = self.u_scalar
a[a <= 5] = u_tensor
a[a == 5] = u_scalar
a[a == 5] = self.u_scalar
z = a + self.t
return z
......@@ -274,36 +332,41 @@ class TensorAssignWithBoolTensorIndex2Error(Cell):
return a
a = np.random.uniform(1,10,[3,4,5])
a = np.arange(60).reshape(3, 4, 5)
ck = np.arange(60).reshape(3, 4, 5)
a4 = np.arange(60).reshape(3, 2, 2, 5)
b = a > 5
c = a < 3
Ta = Tensor(a)
Ta = Tensor(a, dtype=mstype.float32)
Tck = Tensor(ck, dtype=mstype.float32)
Ta4 = Tensor(a4, dtype=mstype.float32)
Tb = Tensor(b)
Tc = Tensor(c)
Td = Tensor([True, True])
u_tensor = Tensor([1])
u_tensor_error = Tensor([1, 2])
t_1d = Tensor([1, 2, 3, 4, 5, 6, 7, 8])
u_tensor = Tensor([1], dtype=mstype.float32)
u_tensor_error = Tensor([1, 2], dtype=mstype.float32)
t_1d = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
tck_1d = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
u_scalar = 5
def test_tensor_assign_bool_index():
net1 = TensorAssignWithBoolTensorIndex()
net2 = TensorAssignWithBoolTensorIndex2()
net1(Ta, Tb, Tc, u_tensor, u_scalar)
net1(Ta, Tb, Tc, u_tensor, u_scalar)
with pytest.raises(ValueError):
net1(Ta, Td, Tc, u_tensor, u_scalar)
with pytest.raises(ValueError):
net1(Ta, u_tensor, Tc, u_tensor, u_scalar)
net1(Ta, Tb, Tc, u_tensor)
net1(Ta, Tb, Tc, u_tensor)
with pytest.raises(ValueError):
net1(Ta, Tb, Td, u_tensor, u_scalar)
net1(Ta, Td, Tc, u_tensor)
with pytest.raises(TypeError):
net1(Ta, u_tensor, Tc, u_tensor)
with pytest.raises(ValueError):
net1(Ta, Tb, Ta, u_tensor, u_scalar)
net1(Ta, Tb, Td, u_tensor)
with pytest.raises(TypeError):
net1(Ta, Tb, Ta, u_tensor)
with pytest.raises(ValueError):
net1(Ta, Tb, Tc, u_tensor_error, u_scalar)
net1(Ta, Tb, Tc, u_tensor_error)
# net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar)
with pytest.raises(ValueError):
net2(Ta, u_tensor_error, u_scalar)
net2(Ta, u_tensor_error)
net3 = TensorAssignWithBoolTensorIndexError()
with pytest.raises(AttributeError):
net3(Ta, Tb, Tc, u_tensor)
......@@ -316,29 +379,41 @@ def test_tensor_assign_bool_index():
net4(Ta, u_scalar)
test_cases = [
('TensorAssignWithTupleEllipsis2', {
'block': TensorAssignWithTupleEllipsis2(),
'desc_inputs': [Ta4, u_tensor],
}),
('TensorAssignWithTupleEllipsis', {
'block': TensorAssignWithTupleEllipsis(),
'desc_inputs': [Ta, u_tensor],
}),
('TensorAssignWithEllipsis', {
'block': TensorAssignWithEllipsis(),
'desc_inputs': [Ta, u_tensor],
}),
('TensorAssignWithTupleInteger', {
'block': TensorAssignWithTupleInteger(),
'desc_inputs': [Ta, u_tensor],
'desc_inputs': [Ta, u_tensor, Tck],
}),
('TensorAssignWithInteger', {
'block': TensorAssignWithInteger(),
'desc_inputs': [Ta, u_tensor],
'desc_inputs': [Ta, u_tensor, Tck],
}),
('TensorAssignWithSlice', {
'block': TensorAssignWithSlice(),
'desc_inputs': [Ta, u_tensor],
'desc_inputs': [Ta, u_tensor, Tck],
}),
('TensorAssignWithSlice2', {
'block': TensorAssignWithSlice2(),
'desc_inputs': [t_1d, u_tensor],
'desc_inputs': [t_1d, u_tensor, tck_1d],
}),
('TensorAssignWithBoolTensorIndex', {
'block': TensorAssignWithBoolTensorIndex(),
'desc_inputs': [Ta, Tb, Tc, u_tensor, u_scalar],
'desc_inputs': [Ta, Tb, Tc, u_tensor],
}),
('TensorAssignWithBoolTensorIndex2', {
'block': TensorAssignWithBoolTensorIndex2(),
'desc_inputs': [Ta, u_tensor, u_scalar],
'desc_inputs': [Ta, u_tensor],
}),
('SlicePositive', {
'block': NetWorkSlicePositive(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册