From fe15239ac04c1d6702ef901da94bbb1bb2a060f2 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 14 Oct 2021 12:44:04 +0800 Subject: [PATCH] fix(imperative): fix error message for tensors with intbx data type GitOrigin-RevId: cbb42f8127320c4d45ac6dcb8171515e53e69bcb --- imperative/python/src/numpy_dtypes_intbx.cpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/imperative/python/src/numpy_dtypes_intbx.cpp b/imperative/python/src/numpy_dtypes_intbx.cpp index 07cb9df23..661e56bac 100644 --- a/imperative/python/src/numpy_dtypes_intbx.cpp +++ b/imperative/python/src/numpy_dtypes_intbx.cpp @@ -24,7 +24,7 @@ template struct LowBitType { static_assert(N < 8, "low bit only supports less than 8 bits"); static int npy_typenum; - //! numerical value (-3, -1, 1, 3) + //! allowed numerical value: odd numbers between (-max_value, max_value) int8_t value; struct PyObj; @@ -32,16 +32,17 @@ struct LowBitType { const static int32_t max_value = (1 << N) - 1; - //! check whether val is (-3, -1, 1, 3) and set python error + //! check whether val is odd and between (-max_value, max_value) and set python error static bool check_value_set_err(int val) { int t = val + max_value; if ((t & 1) || t < 0 || t > (max_value << 1)) { PyErr_SetString( - PyExc_ValueError, mgb::ssprintf( - "low bit dtype number error: " - "value=%d; allowed {-3, -1, 1, 3}", - val) - .c_str()); + PyExc_ValueError, + mgb::ssprintf( + "low bit dtype number error: " + "value=%d; allowed values are odd numbers between [%d,%d]", + val, -max_value, max_value) + .c_str()); return false; } -- GitLab