diff --git a/imperative/python/src/helper.cpp b/imperative/python/src/helper.cpp index ede45bd76f5349accfeb4f229a3f6c098f079066..fc4848902402a8491ff5ac2f95059c10ab98e9cd 100644 --- a/imperative/python/src/helper.cpp +++ b/imperative/python/src/helper.cpp @@ -675,8 +675,17 @@ PyObject* dtype_mgb2np(mgb::DType dtype) { Py_XINCREF(Py_None); return Py_None; } - // NOTE: the following is additional - return reinterpret_cast(descr.release()); + static bool use_typeobj_as_dtype = MGB_GETENV("MGE_USE_TYPEOBJ_AS_DTYPE"); + if (use_typeobj_as_dtype) { + if (dtype.has_param()) { + return reinterpret_cast(descr.release()); + } + PyObject* typeobj = reinterpret_cast(descr->typeobj); + Py_XINCREF(typeobj); + return typeobj; + } else { + return reinterpret_cast(descr.release()); + } } mgb::DType dtype_np2mgb(PyObject* obj) {