提交 2c4cb49a 编写于 作者: B buxue

support interface 'all' and 'any' of tensor

上级 b0b4fa08
......@@ -31,6 +31,41 @@ trans = P.Transpose()
shape_ = P.Shape()
dtype_ = P.DType()
def all_(x, axis=(), keep_dims=False):
"""
Check all array elements along a given axis evaluate to True.
Args:
x (Tensor): A Tensor to be reduced.
axis (Union[None, int, tuple(int)): Dimensions of reduction.
keep_dims (bool): Whether to keep the reduced dimensions.
Returns:
Tensor, has the same data type as x.
"""
reduce_all = P.ReduceAll(keep_dims)
return reduce_all(x, axis)
def any_(x, axis=(), keep_dims=False):
"""
Check any array element along a given axis evaluate to True.
Args:
x (Tensor): A Tensor to be reduced.
axis (Union[None, int, tuple(int)): Dimensions of reduction.
keep_dims (bool): Whether to keep the reduced dimensions.
Returns:
Tensor, has the same data type as x.
"""
reduce_any = P.ReduceAny(keep_dims)
return reduce_any(x, axis)
def transpose(x):
"""Implementation of `transpose`."""
shape = F.shape(x)
......@@ -157,7 +192,6 @@ def check_is_const_int(x, op_name, arg_name):
return True
@constexpr
def check_is_tensor_bool_cond(shp):
"""check if tensor is a bool condition"""
......@@ -316,4 +350,5 @@ def to_array(x):
"""Implementation of `to_array`."""
return x.__ms_to_array__()
tensor_operator_registry.register('__bool__', tensor_bool)
......@@ -143,6 +143,8 @@ BuiltInTypeMap &GetMethodMap() {
}},
{kObjectTypeTensorType,
{
{"all", std::string("all_")}, // C.reduce_all
{"any", std::string("any_")}, // C.reduce_any
{"__add__", std::string("add")}, // C.add
{"__sub__", std::string("sub")}, // C.sub
{"__mul__", std::string("mul")}, // C.mul
......
......@@ -35,9 +35,11 @@ class Registry(UserDict):
new_args = list(args)
new_args.append(obj_str)
return self["vm_compare"](*new_args)
obj = wrap
else:
obj = self[obj_str]
return obj
tensor_operator_registry = Registry()
......@@ -27,7 +27,6 @@ np_types = (np.int8, np.int16, np.int32, np.int64,
np.float32, np.float64, np.bool_)
class Tensor(Tensor_):
"""
Tensor for data storage.
......@@ -205,6 +204,34 @@ class Tensor(Tensor_):
return "Unknown Tensor type!"
return str(self.asnumpy())
def all(self, axis=(), keep_dims=False):
"""
Check all array elements along a given axis evaluate to True.
Args:
axis (Union[None, int, tuple(int)): Dimensions of reduction.
keep_dims (bool): Whether to keep the reduced dimensions.
Returns:
Tensor, has the same data type as x.
"""
return tensor_operator_registry.get('all')(keep_dims)(self, axis)
def any(self, axis=(), keep_dims=False):
"""
Check any array element along a given axis evaluate to True.
Args:
axis (Union[None, int, tuple(int)): Dimensions of reduction.
keep_dims (bool): Whether to keep the reduced dimensions.
Returns:
Tensor, has the same data type as x.
"""
return tensor_operator_registry.get('any')(keep_dims)(self, axis)
@property
def virtual_flag(self):
"""Mark tensor is virtual."""
......@@ -257,6 +284,7 @@ class IndexedSlices:
>>> values = Tensor([[1, 2]], dtype=ms.float32)
>>> Net((3, 2))(indices, values)
"""
def __init__(self, indices, values, dense_shape):
raise NotImplementedError
......@@ -297,5 +325,6 @@ class SparseTensor:
>>> values = Tensor([1, 2], dtype=ms.float32)
>>> Net((3, 4))(indices, values)
"""
def __init__(self, indices, values, dense_shape):
raise NotImplementedError
......@@ -30,7 +30,6 @@ dtype = P.DType()
isconstant = Primitive('is_constant')
isconstant.add_prim_attr('const_value', True)
issubclass_ = P.IsSubClass()
isinstance_ = P.IsInstance()
fill = P.Fill()
......@@ -67,6 +66,7 @@ assign_sub = P.AssignSub()
assign = P.Assign()
square = P.Square()
sqrt = P.Sqrt()
scalar_to_array = P.ScalarToArray()
scalar_to_tensor = P.ScalarToTensor()
tuple_to_array = P.TupleToArray()
......@@ -83,7 +83,6 @@ partial = P.Partial()
# depend: mount a node to another node
depend = P.Depend()
tuple_setitem = Primitive('tuple_setitem')
tuple_getitem = Primitive('tuple_getitem')
list_getitem = Primitive('list_getitem')
......@@ -102,7 +101,6 @@ tuple_equal = Primitive("tuple_equal")
list_equal = Primitive("list_equal")
make_ref = Primitive("make_ref")
scalar_add = Primitive('scalar_add')
scalar_mul = Primitive('scalar_mul')
scalar_sub = Primitive('scalar_sub')
......@@ -154,7 +152,6 @@ shape_mul = Primitive("shape_mul")
# a primitive to compare between tuple.
stop_gradient = Primitive("stop_gradient")
make_indexed_slices = Primitive('MakeIndexedSlices')
indexed_slices_get_values = Primitive('IndexedSlicesGetValues')
indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices')
......@@ -172,7 +169,9 @@ tensor_operator_registry.register('__truediv__', tensor_div)
tensor_operator_registry.register('__mod__', tensor_mod)
tensor_operator_registry.register('__pow__', tensor_pow)
tensor_operator_registry.register('__floordiv__', tensor_floordiv)
#ms cannot support Tensor(True) compare
tensor_operator_registry.register('all', P.ReduceAll)
tensor_operator_registry.register('any', P.ReduceAny)
# ms cannot support Tensor(True) compare
tensor_operator_registry.register('__eq__', equal)
tensor_operator_registry.register('__ne__', not_equal)
tensor_operator_registry.register('__neg__', neg_tensor)
......@@ -181,6 +180,6 @@ tensor_operator_registry.register('__le__', tensor_le)
tensor_operator_registry.register('__gt__', tensor_gt)
tensor_operator_registry.register('__ge__', tensor_ge)
tensor_operator_registry.register('shape', shape)
#support GE backend for no compare operators
# support GE backend for no compare operators
tensor_operator_registry.register('vm_compare', BP.vm_compare)
tensor_operator_registry.register('cast', cast)
......@@ -1111,6 +1111,7 @@ class Mul(_MathBinaryOp):
>>> mul(input_x, input_y)
[4, 10, 18]
"""
def infer_value(self, x, y):
if x is not None and y is not None:
x = x.asnumpy()
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test interface 'all' and 'any' of tensor """
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
def test_all_and_any_of_tensor_in_graph():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x):
all_ = x.all()
any_ = x.any()
all_0 = x.all(0, True)
any_0 = x.any(0, True)
return all_, any_, all_0, any_0
net = Net()
x = Tensor(np.array([[True, False, False], [True, False, False]]))
context.set_context(mode=context.GRAPH_MODE)
net(x)
def test_all_and_any_of_tensor_in_pynative():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x):
all_ = x.all()
any_ = x.any()
all_0 = x.all(0, True)
any_0 = x.any(0, True)
return all_, any_, all_0, any_0
net = Net()
x = Tensor(np.array([[True, False, True], [True, False, False]]))
context.set_context(mode=context.PYNATIVE_MODE)
ret = net(x)
assert ret[0].asnumpy() == np.array(False)
assert ret[1].asnumpy() == np.array(True)
assert ret[2].asnumpy().shape == np.array([[True, False, False]]).shape
assert (ret[2].asnumpy() == np.array([[True, False, False]])).all()
assert ret[3].shape == Tensor(np.array([[True, False, True]])).shape
assert (ret[3] == Tensor(np.array([[True, False, True]]))).all()
......@@ -194,7 +194,19 @@ def vm_impl_all(self):
def vm_impl(x, axis):
x = x.asnumpy()
out = vm.all(x, axis)
out = vm.all(x, axis, self.keep_dims)
return Tensor(out)
return vm_impl
@vm_impl_getters.register(P.ReduceAny)
def vm_impl_any(self):
"""Generate vm_impl function for Any"""
def vm_impl(x, axis):
x = x.asnumpy()
out = vm.any(x, axis, self.keep_dims)
return Tensor(out)
return vm_impl
......
......@@ -67,3 +67,5 @@ setattr(vm, "tanh", tanh)
setattr(vm, "sigmoid", sigmoid)
setattr(vm, 'maximum', maximum)
setattr(vm, 'minimum', minimum)
setattr(vm, 'all', all_)
setattr(vm, 'any', any_)
......@@ -840,3 +840,35 @@ def minimum(x, y):
numpy.ndarray, has the same type as x.
"""
return np.minimum(x, y)
def all_(x, axis=(), keep_dims=False):
"""
Check all array elements along a given axis evaluate to True.
Args:
x (numpy.ndarray): An array to be reduced.
axis (Union[None, int, tuple(int)): Dimensions of reduction.
keep_dims (bool): Whether to keep the reduced dimensions.
Returns:
numpy.ndarray, has the same type as x.
"""
axis = None if axis == () else axis
return np.all(x, axis, keepdims=keep_dims)
def any_(x, axis=(), keep_dims=False):
"""
Check any array element along a given axis evaluate to True.
Args:
x (numpy.ndarray): An array to be reduced.
axis (Union[None, int, tuple(int)): Dimensions of reduction.
keep_dims (bool): Whether to keep the reduced dimensions.
Returns:
numpy.ndarray, has the same type as x.
"""
axis = None if axis == () else axis
return np.any(x, axis, keepdims=keep_dims)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册