diff --git a/imperative/python/src/numpy_dtypes.cpp b/imperative/python/src/numpy_dtypes.cpp index c4d62a8362dab5f4afe62855efcdc34dc0d68ea8..5543aa4f34a14bad1157ca7f55c5c1b75f3fd808 100644 --- a/imperative/python/src/numpy_dtypes.cpp +++ b/imperative/python/src/numpy_dtypes.cpp @@ -91,9 +91,8 @@ bool _is_dtype_equal(PyArray_Descr* dt1, PyArray_Descr* dt2) { PyDict_GetItemString(dt1->metadata, "mgb_dtype"), "zero_point"); PyObject* zp2 = PyDict_GetItemString( PyDict_GetItemString(dt2->metadata, "mgb_dtype"), "zero_point"); - if (!zp1 || !zp2) { - throw py::key_error("zero_point"); - } + if (!zp1 && !zp2) return true; + if (!zp1 || !zp2) return false; return PyLong_AsLong(zp1) == PyLong_AsLong(zp2); } if (!q1 && !q2) {