提交 b6142bee 编写于 作者: M Megvii Engine Team

feat(imperative): support tensor with uint16 date type

GitOrigin-RevId: 57ba0633c7f344922a95f2ac03a7502231453e52
上级 3a7bc37f
...@@ -172,6 +172,7 @@ int to_mgb_supported_dtype_raw(int dtype) { ...@@ -172,6 +172,7 @@ int to_mgb_supported_dtype_raw(int dtype) {
#define FOREACH_NPY_DTYPE_PAIR(cb) \ #define FOREACH_NPY_DTYPE_PAIR(cb) \
cb(Uint8, NPY_UINT8) \ cb(Uint8, NPY_UINT8) \
cb(Int8, NPY_INT8) \ cb(Int8, NPY_INT8) \
cb(Uint16, NPY_UINT16) \
cb(Int16, NPY_INT16) \ cb(Int16, NPY_INT16) \
cb(Int32, NPY_INT32) \ cb(Int32, NPY_INT32) \
cb(Float16, NPY_FLOAT16) \ cb(Float16, NPY_FLOAT16) \
......
...@@ -28,3 +28,10 @@ def test_as_raw_tensor_from_int64(): ...@@ -28,3 +28,10 @@ def test_as_raw_tensor_from_int64():
assert xx.dtype == np.float32 assert xx.dtype == np.float32
assert xx.device == "xpux" assert xx.device == "xpux"
np.testing.assert_almost_equal(yy, x.astype("float32") + 1) np.testing.assert_almost_equal(yy, x.astype("float32") + 1)
def test_as_raw_tensor_uint16():
x = np.arange(6, dtype="uint16").reshape(2, 3)
xx = Tensor(x, device="xpux")
assert xx.dtype == np.uint16
assert xx.device == "xpux"
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册