diff --git a/imperative/python/src/helper.cpp b/imperative/python/src/helper.cpp index d4f325361b471919b167dcfc217e6648c49ff94f..c17ff1d633ee5051a390ca328f587e09764e28f6 100644 --- a/imperative/python/src/helper.cpp +++ b/imperative/python/src/helper.cpp @@ -172,6 +172,7 @@ int to_mgb_supported_dtype_raw(int dtype) { #define FOREACH_NPY_DTYPE_PAIR(cb) \ cb(Uint8, NPY_UINT8) \ cb(Int8, NPY_INT8) \ + cb(Uint16, NPY_UINT16) \ cb(Int16, NPY_INT16) \ cb(Int32, NPY_INT32) \ cb(Float16, NPY_FLOAT16) \ diff --git a/imperative/python/test/unit/core/test_raw_tensor.py b/imperative/python/test/unit/core/test_raw_tensor.py index c8a88453108f9243d82a18b79a4cb77763870aa3..6ef599cc689640e061a4e43a524eb10c55fba02d 100644 --- a/imperative/python/test/unit/core/test_raw_tensor.py +++ b/imperative/python/test/unit/core/test_raw_tensor.py @@ -28,3 +28,10 @@ def test_as_raw_tensor_from_int64(): assert xx.dtype == np.float32 assert xx.device == "xpux" np.testing.assert_almost_equal(yy, x.astype("float32") + 1) + + +def test_as_raw_tensor_uint16(): + x = np.arange(6, dtype="uint16").reshape(2, 3) + xx = Tensor(x, device="xpux") + assert xx.dtype == np.uint16 + assert xx.device == "xpux"