提交 dbce6526 编写于 作者: M Megvii Engine Team

fix(mge/functional): fix return dtype of comparison function

GitOrigin-RevId: 810e32a829ea2b1d0835b3791a6521549f867de1
上级 7dc34769
......@@ -626,7 +626,7 @@ def logaddexp(x: Tensor, y: Tensor) -> Tensor:
def equal(x, y):
r"""Element-wise `(x == y)`."""
return _elwise(x, y, mode=Elemwise.Mode.EQ)
return x == y
def not_equal(x, y):
......@@ -636,22 +636,22 @@ def not_equal(x, y):
def less(x, y):
r"""Element-wise `(x < y)`."""
return _elwise(x, y, mode=Elemwise.Mode.LT)
return x < y
def less_equal(x, y):
r"""Element-wise `(x <= y)`."""
return _elwise(x, y, mode=Elemwise.Mode.LEQ)
return x <= y
def greater(x, y):
r"""Element-wise `(x > y)`."""
return _elwise(y, x, mode=Elemwise.Mode.LT)
return x > y
def greater_equal(x, y):
r"""Element-wise `(x >= y)`."""
return _elwise(y, x, mode=Elemwise.Mode.LEQ)
return x >= y
# other functions
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册