提交 78d21894 编写于 作者: M Megvii Engine Team

feat(mge): rename comparison operators to their full name

GitOrigin-RevId: e503b038a1c1d57d5202d35f7d16f8f09f504a22
上级 e034874a
......@@ -30,19 +30,19 @@ __all__ = [
"cos",
"cosh",
"div",
"eq",
"equal",
"exp",
"expm1",
"fast_tanh",
"floor",
"floor_div",
"gt",
"ge",
"greater",
"greater_equal",
"hswish",
"hsigmoid",
"left_shift",
"lt",
"le",
"less",
"less_equal",
"log",
"log1p",
"logical_and",
......@@ -54,7 +54,7 @@ __all__ = [
"mod",
"mul",
"neg",
"ne",
"not_equal",
"pow",
"relu",
"relu6",
......@@ -102,7 +102,7 @@ def add(x, y):
"""Element-wise `addition`.
At least one operand should be tensor.
Same for sub/mul/div/floor_div/pow/mod/atan2/eq/ne/lt/le/gt/ge/maximum/minmium.
Same for sub/mul/div/floor_div/pow/mod/atan2/equal/not_equal/less/less_equal/greater/greater_equal/maximum/minmium.
:param x: input tensor.
:return: computed tensor.
......@@ -442,7 +442,7 @@ def logical_xor(x, y):
# comparison functions
def eq(x, y):
def equal(x, y):
"""Element-wise `(x == y)`.
:param x: input tensor 1.
......@@ -459,7 +459,7 @@ def eq(x, y):
x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
y = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
out = F.eq(x, y)
out = F.equal(x, y)
print(out.numpy())
Outputs:
......@@ -473,27 +473,27 @@ def eq(x, y):
return _elwise(x, y, mode="eq")
def ne(x, y):
def not_equal(x, y):
"""Element-wise `(x != y)`."""
return x != y
def lt(x, y):
def less(x, y):
"""Element-wise `(x < y)`."""
return _elwise(x, y, mode="lt")
def le(x, y):
def less_equal(x, y):
"""Element-wise `(x <= y)`."""
return _elwise(x, y, mode="leq")
def gt(x, y):
def greater(x, y):
"""Element-wise `(x > y)`."""
return _elwise(y, x, mode="lt")
def ge(x, y):
def greater_equal(x, y):
"""Element-wise `(x >= y)`."""
return _elwise(y, x, mode="leq")
......
......@@ -10,7 +10,7 @@ import numpy as np
from ..core.tensor.utils import make_shape_tuple
from ..tensor import Tensor
from .elemwise import abs, eq, exp, log, maximum, pow, relu
from .elemwise import abs, equal, exp, log, maximum, pow, relu
from .nn import indexing_one_hot
from .tensor import where
......
......@@ -56,9 +56,9 @@ class Elemwise(Module):
* "SIGMOID_GRAD": sigmoid_grad
* "SWITCH_GT0": switch_gt0
* "TANH_GRAD": tanh_grad
* "LT": lt
* "LT": less
* "LEQ": leq
* "EQ": eq
* "EQ": equal
* "POW": pow
* "LOG_SUM_EXP": log_sum_exp
* "FAST_TANH_GRAD": fast_tanh_grad
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册