提交 461d8e3a 编写于 作者: H huangdongrun

add comparison ops

fix pylint

use scalar_lt primitive directly

fix review
上级 9cb219c6
......@@ -92,16 +92,16 @@ convert_object_map = {
T.and_: multitype_ops.logical_and,
T.or_: multitype_ops.logical_or,
T.xor: NO_IMPLEMENT,
T.pos: F.scalar_uadd,
T.pos: multitype_ops.uadd,
T.neg: multitype_ops.negative,
T.invert: NO_IMPLEMENT,
T.not_: F.bool_not,
T.not_: multitype_ops.logical_not,
T.eq: multitype_ops.equal,
T.ne: F.scalar_ne,
T.ne: multitype_ops.not_equal,
T.lt: multitype_ops.less,
T.gt: F.scalar_gt,
T.gt: multitype_ops.greater,
T.le: multitype_ops.less_equal,
T.ge: F.scalar_ge,
T.ge: multitype_ops.greater_equal,
T.is_: F.is_,
T.is_not: F.is_not,
T.contains: NO_IMPLEMENT,
......
......@@ -23,23 +23,33 @@ from .getitem_impl import getitem
from .zeros_like_impl import zeros_like
from .ones_like_impl import ones_like
from .equal_impl import equal
from .not_equal_impl import not_equal
from .less_impl import less
from .less_equal_impl import less_equal
from .greater_impl import greater
from .greater_equal_impl import greater_equal
from .negative_impl import negative
from .logical_and_impl import logical_and
from .logical_or_impl import logical_or
from .logic_not_impl import logical_not
from .uadd_impl import uadd
__all__ = [
'add',
'sub',
'mul',
'div',
'uadd',
'zeros_like',
'ones_like',
'equal',
'not_equal',
'less',
'less_equal',
'greater',
'greater_equal',
'negative',
'getitem',
'logical_and',
'logical_or'
'logical_or',
'logical_not'
]
......@@ -190,7 +190,8 @@ def _none_equal_tuple(x, y):
"""
return False
@equal.register("Tensor", "Number")
@equal.register("Number", "Tensor")
@equal.register("Tensor", "Tensor")
def _tensor_equal_tensor(x, y):
"""
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""greater_equal_impl"""
from mindspore.ops.composite import base
from mindspore.ops import functional as F
# greater_equal is a metagraph object which will determine if two objects are greater_equal according to input type
# using ".register" decorator
greater_equal = base.MultitypeFuncGraph("greater_equal")
@greater_equal.register("Number", "Number")
def _greater_equal_scala(x, y):
"""
Determine whether x is greater equal than y
Args:
x(Number): Number.
y(Number): Number.
Returns:
bool, if x >= y return true, x < y return false.
"""
return F.scalar_ge(x, y)
@greater_equal.register("Tensor", "Number")
@greater_equal.register("Number", "Tensor")
@greater_equal.register("Tensor", "Tensor")
def _greater_equal_tensor(x, y):
"""
Determine whether tensor x is greater equal than tensor y elementwise
Args:
x(Tensor): Tensor.
y(Tensor): Tensor.
Returns:
Tensor, return value by operator P.GreaterEqual.
"""
return F.tensor_ge(x, y)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Ungreater required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""equal_impl"""
from mindspore.ops.composite import base
from mindspore.ops import functional as F
# greater is a metafuncgraph object which will determine if two objects are greater according to input type
# using ".register" decorator
greater = base.MultitypeFuncGraph("greater")
@greater.register("Number", "Number")
def _greater_scala(x, y):
"""
Determine whether two numbers are greater.
Args:
x(Number): Number.
y(Number): Number.
Returns:
bool, if x > y return true, x <= y return false.
"""
return F.scalar_gt(x, y)
@greater.register("Tensor", "Number")
@greater.register("Number", "Tensor")
@greater.register("Tensor", "Tensor")
def _greater_tensor(x, y):
"""
Determine whether two tensor are greater by element.
Args:
x(Tensor): Tensor.
y(Tensor): Tensor.
Returns:
tensor, return operation of x and y by P.Greater
"""
return F.tensor_gt(x, y)
......@@ -36,7 +36,8 @@ def _less_equal_scala(x, y):
"""
return F.scalar_le(x, y)
@less_equal.register("Tensor", "Number")
@less_equal.register("Number", "Tensor")
@less_equal.register("Tensor", "Tensor")
def _less_equal_tensor(x, y):
"""
......
......@@ -36,7 +36,8 @@ def _less_scala(x, y):
"""
return F.scalar_lt(x, y)
@less.register("Tensor", "Number")
@less.register("Number", "Tensor")
@less.register("Tensor", "Tensor")
def _less_tensor(x, y):
"""
......@@ -47,6 +48,6 @@ def _less_tensor(x, y):
y(Tensor): Tensor.
Returns:
bool, if x and y are less elements by element return true, else return false.
Tensor, return value of x and y by operation P.Less()
"""
return F.tensor_lt(x, y)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""logical_not_impl"""
from mindspore.ops.composite import base
from mindspore.ops import functional as F
# logical_not is a metagraph object which will generate function according to input type
# using ".register" decorator
logical_not = base.MultitypeFuncGraph("logical_not")
@logical_not.register("Number")
def _logical_not_scala(x):
"""
Return logical not operation result of x
Args:
x(Number): Number.
Returns:
bool, Return logical not operation result of x
"""
return F.bool_not(x.__bool__())
@logical_not.register("Tensor")
def _logical_not_tensor(x):
"""
Return logical not operation result of x
Args:
x(Tensor): Tensor.
Returns:
Tensor, Return logical not operation result of x
"""
return F.logical_not(x)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""not_equal_impl"""
from ...composite import base
from ... import functional as F
not_equal = base.MultitypeFuncGraph("not_equal")
"""
not_equal is a metafuncgraph object which will determine if two objects are not_equal according to input type
using ".register" decorator
"""
@not_equal.register("Number", "Number")
def _not_equal_scalar(x, y):
"""
Determine if two numbers is not equal.
Args:
x (Number): x
y (NUmber): y
Returns:
bool, if x != y return true, x == y return false.
"""
return not F.scalar_eq(x, y)
@not_equal.register("String", "String")
def _not_equal_string(x, y):
"""
Determine if two strings are not equal.
Args:
x: str
y: str
Returns:
bool, if x != y return true, x == y return false.
"""
return not F.string_eq(x, y)
@not_equal.register("String", "None")
def _string_not_equal_none(x, y):
"""
Determine if string not equals none.
Args:
x: str.
y: None.
Returns:
bool, return True.
"""
return True
@not_equal.register("None", "String")
def _none_not_equal_string(x, y):
"""
Determine if string not equals none.
Args:
x: None.
y: str.
Returns:
bool, return True.
"""
return True
@not_equal.register("None", "None")
def _none_not_equal_none(x, y):
"""
Determine if none not equals none.
Args:
x: None.
y: None.
Returns:
bool, return False.
"""
return False
@not_equal.register("Number", "None")
def _scalar_not_equal_none(x, y):
"""
Determine if number not equals none.
Args:
x: Number.
y: None.
Returns:
bool, return True.
"""
return True
@not_equal.register("None", "Number")
def _none_not_equal_scalar(x, y):
"""
Determine if number not_equals none.
Args:
x: None.
y: NUmber.
Returns:
bool, return True.
"""
return True
@not_equal.register("Tuple", "Tuple")
def _euqal_tuple(x, y):
"""
Determine if two tuples are not equal by element.
Args:
x (tuple): x
y (tuple): y
Returns:
bool, if x and y are not equal by element return true, else return false.
"""
return not F.tuple_equal(x, y)
@not_equal.register("List", "List")
def _euqal_list(x, y):
"""
Determine if two lists are not equal by element.
Args:
x (list): x
y (list): y
Returns:
bool, if x and y are not equal by element return true, else return false.
"""
return not F.list_equal(x, y)
@not_equal.register("Tuple", "None")
def _tuple_euqal_none(x, y):
"""
Determine if tuple element not equals none element.
Args:
x: Tuple.
y: None.
Returns:
bool, return True.
"""
return True
@not_equal.register("None", "Tuple")
def _none_not_equal_tuple(x, y):
"""
Determine if tuple element not equals none element.
Args:
x: None.
y: Tuple.
Returns:
bool, return True.
"""
return True
@not_equal.register("Tensor", "Number")
@not_equal.register("Number", "Tensor")
@not_equal.register("Tensor", "Tensor")
def _tensor_not_equal_tensor(x, y):
"""
Determine if two tensors are not_equal.
Args:
x : Tensor.
y : Tensor.
Returns:
bool, if x == y return true, x != y return false.
"""
return F.not_equal(x, y)
@not_equal.register("Tensor", "None")
def _tensor_not_equal_none(x, y):
"""
Determine if tensor not_equal none.
Args:
x : Tensor.
y : None.
Returns:
bool, return True.
"""
return True
@not_equal.register("None", "Tensor")
def _none_not_equal_tensor(x, y):
"""
Determine if tensor not equal none.
Args:
x : None.
y : Tensor.
Returns:
bool, return True.
"""
return True
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""uadd_impl"""
from mindspore.ops.composite import base
# uadd is a metagraph object which will return operation result regarding input
# using ".register" decorator
uadd = base.MultitypeFuncGraph("uadd")
@uadd.register("Tensor")
@uadd.register("Number")
def _uadd_scala(x):
return x
......@@ -43,12 +43,15 @@ tensor_add = P.TensorAdd()
neg_tensor = P.Neg()
tensor_lt = P.Less()
tensor_le = P.LessEqual()
tensor_gt = P.Greater()
tensor_ge = P.GreaterEqual()
tensor_sub = P.Sub()
tensor_mul = P.Mul()
tensor_div = P.RealDiv()
strided_slice = P.StridedSlice()
same_type_shape = P.SameTypeShape()
equal = P.Equal()
not_equal = P.NotEqual()
assign_sub = P.AssignSub()
assign = P.Assign()
square = P.Square()
......@@ -97,6 +100,7 @@ bool_or = Primitive("bool_or")
bool_and = Primitive("bool_and")
logical_and = P.LogicalAnd()
logical_or = P.LogicalOr()
logical_not = P.LogicalNot()
array_to_scalar = Primitive('array_to_scalar')
is_ = Primitive("is_")
is_not = Primitive("is_not")
......
......@@ -17,6 +17,7 @@ from mindspore.ops import Primitive
scala_add = Primitive('scalar_add')
scala_mul = Primitive('scalar_mul')
scalar_gt = Primitive('scalar_gt')
def scalar_add(x, y):
"""Implement `scalar_add`."""
return scala_add(x, y)
......@@ -26,6 +27,6 @@ def scalar_mul(x, y):
return scala_mul(x, y)
def test_if(x, y):
if x > y:
if scalar_gt(x, y):
return x
return y
......@@ -31,8 +31,20 @@ class ComparisonOpsNet(nn.Cell):
def __init__(self):
super(ComparisonOpsNet, self).__init__()
def construct(self, x, y):
ret = x <= y
return ret
a = x <= y
b = x <= 1.0
c = y >= 1.0
d = y >= x
e = x < y
f = x < 1.0
g = 1.0 > y
h = y > x
i = y == 3.0
j = x != 4
k = + x
l = + 1.0
m = k != l
return a or b or c or d or e or f or g or h or i or j or m
class LogicalNumberOpsNet(nn.Cell):
def __init__(self):
......@@ -41,7 +53,7 @@ class LogicalNumberOpsNet(nn.Cell):
self.one = 0
self.zero = 0.0
def construct(self, x, y):
if self.cond and self.one or self.zero:
if self.cond and self.one or self.zero and not self.one:
return x + y
return x - y
......@@ -51,7 +63,7 @@ class LogicalTensorOpsNet(nn.Cell):
super(LogicalTensorOpsNet, self).__init__()
self.const_true = Tensor(True, dtype=mstype.bool_)
def construct(self, x, y):
ret = x and y and (y or self.const_true)
ret = x and y and (y or self.const_true) and (not self.const_true)
return ret
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册