提交 717b88e6 编写于 作者: M Megvii Engine Team

fix(mge/elemwise): fix problem that elemwise.mode is not comparable with string mode

GitOrigin-RevId: 82e39be0a975cc72dfe1fe7c206be218c2ada131
上级 9c17cfc4
......@@ -191,10 +191,13 @@ struct EnumWrapper {
.release().ptr();
}
static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op) {
T lhs = reinterpret_cast<EnumWrapper*>(self)->value,
rhs = reinterpret_cast<EnumWrapper*>(other)->value;
if (op == Py_EQ || op == Py_NE) {
T lhs, rhs;
if (load(other, rhs) && load(self, lhs)) {
RETURN_RICHCOMPARE(lhs, rhs, op);
} else {
RETURN_RICHCOMPARE(0, 1, op);
}
}
Py_RETURN_NOTIMPLEMENTED;
}
......@@ -296,10 +299,13 @@ struct BitCombinedEnumWrapper {
return cast(lhs & rhs);
}
static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) {
T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
if (op == Py_EQ || op == Py_NE) {
T lhs, rhs;
if (load(other, rhs) && load(self, lhs)) {
RETURN_RICHCOMPARE(lhs, rhs, op);
} else {
RETURN_RICHCOMPARE(0, 1, op);
}
}
Py_RETURN_NOTIMPLEMENTED;
}
......
......@@ -12,7 +12,7 @@ import megengine.functional as F
import megengine.functional.elemwise as elemwise
from megengine import tensor
from megengine.core.tensor import dtype
from megengine.functional.elemwise import _elwise
from megengine.functional.elemwise import Elemwise, _elwise
def test_abs():
......@@ -25,14 +25,10 @@ def test_abs():
def test_elemwise_mode_string():
np.testing.assert_allclose(
_elwise(tensor([-3.0, -4.0, -5.0]), mode="ABS").numpy(),
np.abs(np.array([-3.0, -4.0, -5.0], dtype=np.float32)),
)
np.testing.assert_allclose(
_elwise(-3.0, mode="ABS").numpy(), np.abs(np.float32(-3.0))
)
for key, mode in vars(Elemwise.Mode).items():
if isinstance(mode, Elemwise.Mode):
assert key == mode
assert Elemwise(mode=key) == Elemwise(mode=mode)
def test_multiply():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册