提交 78d21894 编写于 作者: M Megvii Engine Team

feat(mge): rename comparison operators to their full name

GitOrigin-RevId: e503b038a1c1d57d5202d35f7d16f8f09f504a22
上级 e034874a
...@@ -30,19 +30,19 @@ __all__ = [ ...@@ -30,19 +30,19 @@ __all__ = [
"cos", "cos",
"cosh", "cosh",
"div", "div",
"eq", "equal",
"exp", "exp",
"expm1", "expm1",
"fast_tanh", "fast_tanh",
"floor", "floor",
"floor_div", "floor_div",
"gt", "greater",
"ge", "greater_equal",
"hswish", "hswish",
"hsigmoid", "hsigmoid",
"left_shift", "left_shift",
"lt", "less",
"le", "less_equal",
"log", "log",
"log1p", "log1p",
"logical_and", "logical_and",
...@@ -54,7 +54,7 @@ __all__ = [ ...@@ -54,7 +54,7 @@ __all__ = [
"mod", "mod",
"mul", "mul",
"neg", "neg",
"ne", "not_equal",
"pow", "pow",
"relu", "relu",
"relu6", "relu6",
...@@ -102,7 +102,7 @@ def add(x, y): ...@@ -102,7 +102,7 @@ def add(x, y):
"""Element-wise `addition`. """Element-wise `addition`.
At least one operand should be tensor. At least one operand should be tensor.
Same for sub/mul/div/floor_div/pow/mod/atan2/eq/ne/lt/le/gt/ge/maximum/minmium. Same for sub/mul/div/floor_div/pow/mod/atan2/equal/not_equal/less/less_equal/greater/greater_equal/maximum/minmium.
:param x: input tensor. :param x: input tensor.
:return: computed tensor. :return: computed tensor.
...@@ -442,7 +442,7 @@ def logical_xor(x, y): ...@@ -442,7 +442,7 @@ def logical_xor(x, y):
# comparison functions # comparison functions
def eq(x, y): def equal(x, y):
"""Element-wise `(x == y)`. """Element-wise `(x == y)`.
:param x: input tensor 1. :param x: input tensor 1.
...@@ -459,7 +459,7 @@ def eq(x, y): ...@@ -459,7 +459,7 @@ def eq(x, y):
x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3)) x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
y = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3)) y = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
out = F.eq(x, y) out = F.equal(x, y)
print(out.numpy()) print(out.numpy())
Outputs: Outputs:
...@@ -473,27 +473,27 @@ def eq(x, y): ...@@ -473,27 +473,27 @@ def eq(x, y):
return _elwise(x, y, mode="eq") return _elwise(x, y, mode="eq")
def ne(x, y): def not_equal(x, y):
"""Element-wise `(x != y)`.""" """Element-wise `(x != y)`."""
return x != y return x != y
def lt(x, y): def less(x, y):
"""Element-wise `(x < y)`.""" """Element-wise `(x < y)`."""
return _elwise(x, y, mode="lt") return _elwise(x, y, mode="lt")
def le(x, y): def less_equal(x, y):
"""Element-wise `(x <= y)`.""" """Element-wise `(x <= y)`."""
return _elwise(x, y, mode="leq") return _elwise(x, y, mode="leq")
def gt(x, y): def greater(x, y):
"""Element-wise `(x > y)`.""" """Element-wise `(x > y)`."""
return _elwise(y, x, mode="lt") return _elwise(y, x, mode="lt")
def ge(x, y): def greater_equal(x, y):
"""Element-wise `(x >= y)`.""" """Element-wise `(x >= y)`."""
return _elwise(y, x, mode="leq") return _elwise(y, x, mode="leq")
......
...@@ -10,7 +10,7 @@ import numpy as np ...@@ -10,7 +10,7 @@ import numpy as np
from ..core.tensor.utils import make_shape_tuple from ..core.tensor.utils import make_shape_tuple
from ..tensor import Tensor from ..tensor import Tensor
from .elemwise import abs, eq, exp, log, maximum, pow, relu from .elemwise import abs, equal, exp, log, maximum, pow, relu
from .nn import indexing_one_hot from .nn import indexing_one_hot
from .tensor import where from .tensor import where
......
...@@ -56,9 +56,9 @@ class Elemwise(Module): ...@@ -56,9 +56,9 @@ class Elemwise(Module):
* "SIGMOID_GRAD": sigmoid_grad * "SIGMOID_GRAD": sigmoid_grad
* "SWITCH_GT0": switch_gt0 * "SWITCH_GT0": switch_gt0
* "TANH_GRAD": tanh_grad * "TANH_GRAD": tanh_grad
* "LT": lt * "LT": less
* "LEQ": leq * "LEQ": leq
* "EQ": eq * "EQ": equal
* "POW": pow * "POW": pow
* "LOG_SUM_EXP": log_sum_exp * "LOG_SUM_EXP": log_sum_exp
* "FAST_TANH_GRAD": fast_tanh_grad * "FAST_TANH_GRAD": fast_tanh_grad
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册