提交 3ba31ec1 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!419 Tensor assign with bool Tensor

Merge pull request !419 from candanzg/tensor_assign_bool_index
...@@ -83,6 +83,7 @@ convert_object_map = { ...@@ -83,6 +83,7 @@ convert_object_map = {
T.mul: multitype_ops.mul, T.mul: multitype_ops.mul,
T.truediv: multitype_ops.div, T.truediv: multitype_ops.div,
T.getitem: multitype_ops.getitem, T.getitem: multitype_ops.getitem,
T.setitem: multitype_ops.setitem,
T.floordiv: multitype_ops.floordiv, T.floordiv: multitype_ops.floordiv,
T.mod: multitype_ops.mod, T.mod: multitype_ops.mod,
T.pow: multitype_ops.pow_, T.pow: multitype_ops.pow_,
...@@ -118,7 +119,6 @@ convert_object_map = { ...@@ -118,7 +119,6 @@ convert_object_map = {
T.iter: M.ms_iter, T.iter: M.ms_iter,
T.next: M.ms_next, T.next: M.ms_next,
T.hasnext: M.hasnext, T.hasnext: M.hasnext,
T.setitem: M.setitem,
T.make_tuple: F.make_tuple, T.make_tuple: F.make_tuple,
T.make_dict: F.make_dict, T.make_dict: F.make_dict,
......
...@@ -23,6 +23,7 @@ from .pow_impl import pow_ ...@@ -23,6 +23,7 @@ from .pow_impl import pow_
from .floordiv_impl import floordiv from .floordiv_impl import floordiv
from .mod_impl import mod from .mod_impl import mod
from .getitem_impl import getitem from .getitem_impl import getitem
from .setitem_impl import setitem
from .zeros_like_impl import zeros_like from .zeros_like_impl import zeros_like
from .ones_like_impl import ones_like from .ones_like_impl import ones_like
from .equal_impl import equal from .equal_impl import equal
...@@ -55,6 +56,7 @@ __all__ = [ ...@@ -55,6 +56,7 @@ __all__ = [
'greater_equal', 'greater_equal',
'negative', 'negative',
'getitem', 'getitem',
'setitem',
'logical_and', 'logical_and',
'logical_or', 'logical_or',
'logical_not' 'logical_not'
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""constexpr util"""
from ...primitive import constexpr
@constexpr
def is_same_type(inst, type_):
"""
Check whether an object is an instance of a target type.
Inputs:
inst (mindspore.dtype): Inspected type.
type_ (mindspore.dtype): Target type.
Outputs:
bool, the check result.
"""
return inst == type_
@constexpr
def error_msg(msg="", format_values=""):
"""
Used to throw exception information.
Inputs:
msg (str): information content.
"""
raise ValueError(msg.format(*format_values))
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Implementation for setitem."""
from ...composite import base
from ....common import dtype as mstype
from ... import functional as F
from . import _multitype_ops_util as mult_util
setitem = base.MultitypeFuncGraph('setitem')
@setitem.register("List", "Number", "String")
def _list_setitem_with_string(data, number_index, value):
"""
Assign value to list.
Inputs:
data (list): Data of type lis.
number_index (Number): Index of data.
value (String): Value given.
Outputs:
List, type is same as the element type of data.
"""
return F.list_setitem(data, number_index, value)
@setitem.register("List", "Number", "Number")
def _list_setitem_with_number(data, number_index, value):
"""
Assign value to list.
Inputs:
data (list): Data of type lis.
number_index (Number): Index of data.
value (Number): Value given.
Outputs:
List, type is same as the element type of data.
"""
return F.list_setitem(data, number_index, value)
@setitem.register("List", "Number", "Tensor")
def _list_setitem_with_Tensor(data, number_index, value):
"""
Assign value to list.
Inputs:
data (list): Data of type lis.
number_index (Number): Index of data.
value (Tensor): Value given.
Outputs:
List, type is same as the element type of data.
"""
return F.list_setitem(data, number_index, value)
@setitem.register("List", "Number", "List")
def _list_setitem_with_List(data, number_index, value):
"""
Assign value to list.
Inputs:
data (list): Data of type lis.
number_index (Number): Index of data.
value (List): Value given.
Outputs:
List, type is same as the element type of data.
"""
return F.list_setitem(data, number_index, value)
@setitem.register("Dictionary", "String", "Tensor")
def _dict_setitem_with_tensor(data, key, value):
"""
Assign value to dictionary.
Inputs:
data (Dictionary): Data of type dict.
key (str): Key of the data.
value (Tensor): Value given.
Outputs:
Dict, type is as same as the element type of data.
"""
return F.dict_setitem(data, key, value)
@setitem.register("Dictionary", "String", "Number")
def _dict_setitem_with_number(data, key, value):
"""
Assign value to dictionary.
Inputs:
data (Dictionary): Data of type dict.
key (str): Key of the data.
value (Number): Value given.
Outputs:
Dict, type is as same as the element type of data.
"""
return F.dict_setitem(data, key, value)
@setitem.register("Tensor", "Tensor", "Tensor")
def _tensor_setitem_by_tensor_v1(data, index, value_tensor):
"""
Tensor assignment.
Note:
Syntax support: A[B] = U and A[A>n] = U.
Restraint condition: 1) A, U is a Tensor, and B is a bool Tensor.
2) A.shape == B.shape
3) U.size == 1
4) n is a number
Inputs:
data (Tensor): Assigned tensor.
index (Tensor): Tensor of bool type.
value_tensor (Tensor): Tensor with size 1.
Outputs:
Tensor, element type and shape is same as data.
"""
index_dtype = F.dtype(index)
index_shape = F.shape(index)
is_bool = mult_util.is_same_type(index_dtype, mstype.bool_)
if not is_bool:
return mult_util.error_msg(
"The tensor index should be a bool type tensor. {} type tensor is not supported yet.", (index_dtype,))
data_shape = F.shape(data)
if index_shape != data_shape:
return mult_util.error_msg(
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.", (data_shape, index_shape))
size = F.size(value_tensor)
if size != 1:
return mult_util.error_msg(
"When assign value is a tensor, its size should be 1, but current size is {}.", (size,))
dtype = F.dtype(data)
u_cast = F.cast(value_tensor, dtype)
one_data = F.ones_like(data)
u = F.tensor_mul(one_data, u_cast)
return F.select(index, u, data)
@setitem.register("Tensor", "Tensor", "Number")
def _tensor_setitem_by_tensor_v2(data, index, value):
"""
Tensor assignment.
Note:
Syntax support: A[B] = u and A[A>n] = u.
Restraint condition: 1) A is a Tensor, and B is a bool Tensor.
2) A.shape == B.shape
3) u is a scalar
4) n is a number
Inputs:
data (Tensor): Assigned tensor.
index (Tensor): Tensor of bool type.
value_tensor (Number): Assignment value.
Outputs:
Tensor, element type and shape is same as data.
"""
index_dtype = F.dtype(index)
index_shape = F.shape(index)
is_bool = mult_util.is_same_type(index_dtype, mstype.bool_)
if not is_bool:
return mult_util.error_msg(
"The tensor index should be a bool type tensor. {} type tensor is not supported yet.", (index_dtype,))
shape = F.shape(data)
if index_shape != shape:
return mult_util.error_msg(
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.", (shape, index_shape))
dtype = F.dtype(data)
u = F.fill(dtype, shape, value)
return F.select(index, u, data)
...@@ -31,6 +31,9 @@ dtype = P.DType() ...@@ -31,6 +31,9 @@ dtype = P.DType()
issubclass_ = P.IsSubClass() issubclass_ = P.IsSubClass()
isinstance_ = P.IsInstance() isinstance_ = P.IsInstance()
fill = P.Fill() fill = P.Fill()
select = P.Select()
size = P.Size()
ones_like = P.OnesLike()
shape = P.Shape() shape = P.Shape()
rank = P.Rank() rank = P.Rank()
reshape = P.Reshape() reshape = P.Reshape()
...@@ -68,7 +71,9 @@ scalar_cast = P.ScalarCast() ...@@ -68,7 +71,9 @@ scalar_cast = P.ScalarCast()
tuple_setitem = Primitive('tuple_setitem') tuple_setitem = Primitive('tuple_setitem')
tuple_getitem = Primitive('tuple_getitem') tuple_getitem = Primitive('tuple_getitem')
list_getitem = Primitive('list_getitem') list_getitem = Primitive('list_getitem')
list_setitem = Primitive('list_setitem')
dict_getitem = Primitive('dict_getitem') dict_getitem = Primitive('dict_getitem')
dict_setitem = Primitive('dict_setitem')
tuple_div = Primitive("tuple_div") tuple_div = Primitive("tuple_div")
tuple_len = Primitive("tuple_len") tuple_len = Primitive("tuple_len")
tuple_reversed = Primitive("tuple_reversed") tuple_reversed = Primitive("tuple_reversed")
......
...@@ -18,6 +18,7 @@ import pytest ...@@ -18,6 +18,7 @@ import pytest
from mindspore import Tensor from mindspore import Tensor
from mindspore import context from mindspore import context
from mindspore import dtype as mstype
from mindspore.nn import Cell from mindspore.nn import Cell
from ....mindspore_test_framework.mindspore_test import mindspore_test from ....mindspore_test_framework.mindspore_test import mindspore_test
...@@ -79,7 +80,102 @@ class NetWorkReduceToScalar(Cell): ...@@ -79,7 +80,102 @@ class NetWorkReduceToScalar(Cell):
return ret return ret
class TensorAssignWithBoolTensorIndex(Cell):
def __init__(self):
super(TensorAssignWithBoolTensorIndex, self).__init__()
self.t = Tensor(np.arange(6).reshape([2,3]), dtype = mstype.float64)
def construct(self, a, b, c, u_tensor, _scalar):
a[c] = u_scalar
a[b] = u_tensor
z = a + self.t
return z
class TensorAssignWithBoolTensorIndexError(Cell):
def __init__(self):
super(TensorAssignWithBoolTensorIndexError, self).__init__()
def construct(self, a, b, c, u_tensor):
a[b][c] = u_tensor
return a
class TensorAssignWithBoolTensorIndex2(Cell):
def __init__(self):
super(TensorAssignWithBoolTensorIndex2, self).__init__()
self.t = Tensor(np.arange(6).reshape([2,3]), dtype = mstype.float64)
def construct(self, a, u_tensor, _scalar):
a[a>8] = u_tensor
a[a>=6] = u_scalar
a[a<3] = u_scalar
a[a<=5] = u_tensor
a[a==5] = u_scalar
z = a + self.t
return z
class TensorAssignWithBoolTensorIndex2Error(Cell):
def __init__(self):
super(TensorAssignWithBoolTensorIndex2Error, self).__init__()
def construct(self, a, u_tensor):
a[a>8][a>5] = u_tensor
return a
a = np.random.uniform(1,10,[2,3])
b = a > 5
c = a < 3
Ta = Tensor(a)
Tb = Tensor(b)
Tc = Tensor(c)
Td = Tensor([True, True])
u_tensor = Tensor([1])
u_tensor_error = Tensor([1, 2])
u_scalar = 5
def test_tensor_assign_bool_index():
net1 = TensorAssignWithBoolTensorIndex()
net2 = TensorAssignWithBoolTensorIndex2()
net1(Ta, Tb, Tc, u_tensor, u_scalar)
with pytest.raises(ValueError):
net1(Ta, Td, Tc, u_tensor, u_scalar)
with pytest.raises(ValueError):
net1(Ta, u_tensor, Tc, u_tensor, u_scalar)
with pytest.raises(ValueError):
net1(Ta, Tb, Td, u_tensor, u_scalar)
with pytest.raises(ValueError):
net1(Ta, Tb, Ta, u_tensor, u_scalar)
with pytest.raises(ValueError):
net1(Ta, Tb, Tc, u_tensor_error, u_scalar)
#net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar)
with pytest.raises(ValueError):
net2(Ta, u_tensor_error, u_scalar)
net3 = TensorAssignWithBoolTensorIndexError()
with pytest.raises(AttributeError):
net3(Ta, Tb, Tc, u_tensor)
with pytest.raises(AttributeError):
net3(Ta, Tb, Tc, u_scalar)
net4 = TensorAssignWithBoolTensorIndex2Error()
with pytest.raises(AttributeError):
net4(Ta, u_tensor)
with pytest.raises(AttributeError):
net4(Ta, u_scalar)
test_cases = [ test_cases = [
('TensorAssignWithBoolTensorIndex', {
'block': TensorAssignWithBoolTensorIndex(),
'desc_inputs': [Ta, Tb, Tc, u_tensor, u_scalar],
}),
('TensorAssignWithBoolTensorIndex2', {
'block': TensorAssignWithBoolTensorIndex2(),
'desc_inputs': [Ta, u_tensor, u_scalar],
}),
('SlicePositive', { ('SlicePositive', {
'block': NetWorkSlicePositive(), 'block': NetWorkSlicePositive(),
'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))], 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册