提交 717b88e6 编写于 作者: M Megvii Engine Team

fix(mge/elemwise): fix problem that elemwise.mode is not comparable with string mode

GitOrigin-RevId: 82e39be0a975cc72dfe1fe7c206be218c2ada131
上级 9c17cfc4
...@@ -191,10 +191,13 @@ struct EnumWrapper { ...@@ -191,10 +191,13 @@ struct EnumWrapper {
.release().ptr(); .release().ptr();
} }
static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op) { static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op) {
T lhs = reinterpret_cast<EnumWrapper*>(self)->value,
rhs = reinterpret_cast<EnumWrapper*>(other)->value;
if (op == Py_EQ || op == Py_NE) { if (op == Py_EQ || op == Py_NE) {
T lhs, rhs;
if (load(other, rhs) && load(self, lhs)) {
RETURN_RICHCOMPARE(lhs, rhs, op); RETURN_RICHCOMPARE(lhs, rhs, op);
} else {
RETURN_RICHCOMPARE(0, 1, op);
}
} }
Py_RETURN_NOTIMPLEMENTED; Py_RETURN_NOTIMPLEMENTED;
} }
...@@ -296,10 +299,13 @@ struct BitCombinedEnumWrapper { ...@@ -296,10 +299,13 @@ struct BitCombinedEnumWrapper {
return cast(lhs & rhs); return cast(lhs & rhs);
} }
static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) { static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) {
T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
if (op == Py_EQ || op == Py_NE) { if (op == Py_EQ || op == Py_NE) {
T lhs, rhs;
if (load(other, rhs) && load(self, lhs)) {
RETURN_RICHCOMPARE(lhs, rhs, op); RETURN_RICHCOMPARE(lhs, rhs, op);
} else {
RETURN_RICHCOMPARE(0, 1, op);
}
} }
Py_RETURN_NOTIMPLEMENTED; Py_RETURN_NOTIMPLEMENTED;
} }
......
...@@ -12,7 +12,7 @@ import megengine.functional as F ...@@ -12,7 +12,7 @@ import megengine.functional as F
import megengine.functional.elemwise as elemwise import megengine.functional.elemwise as elemwise
from megengine import tensor from megengine import tensor
from megengine.core.tensor import dtype from megengine.core.tensor import dtype
from megengine.functional.elemwise import _elwise from megengine.functional.elemwise import Elemwise, _elwise
def test_abs(): def test_abs():
...@@ -25,14 +25,10 @@ def test_abs(): ...@@ -25,14 +25,10 @@ def test_abs():
def test_elemwise_mode_string(): def test_elemwise_mode_string():
np.testing.assert_allclose( for key, mode in vars(Elemwise.Mode).items():
_elwise(tensor([-3.0, -4.0, -5.0]), mode="ABS").numpy(), if isinstance(mode, Elemwise.Mode):
np.abs(np.array([-3.0, -4.0, -5.0], dtype=np.float32)), assert key == mode
) assert Elemwise(mode=key) == Elemwise(mode=mode)
np.testing.assert_allclose(
_elwise(-3.0, mode="ABS").numpy(), np.abs(np.float32(-3.0))
)
def test_multiply(): def test_multiply():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册