未验证 提交 245252b8 编写于 作者: Z Zhou Wei 提交者: GitHub

fix bug when dtype of to_tensor is core.VarType (#31931)

上级 e1f93161
...@@ -76,6 +76,11 @@ class TestVarBase(unittest.TestCase): ...@@ -76,6 +76,11 @@ class TestVarBase(unittest.TestCase):
y = x.cuda(blocking=True) y = x.cuda(blocking=True)
self.assertEqual(y.place.__repr__(), "CUDAPlace(0)") self.assertEqual(y.place.__repr__(), "CUDAPlace(0)")
# support 'dtype' is core.VarType
x = paddle.rand((2, 2))
y = paddle.to_tensor([2, 2], dtype=x.dtype)
self.assertEqual(y.dtype, core.VarDesc.VarType.FP32)
# set_default_dtype take effect on complex # set_default_dtype take effect on complex
x = paddle.to_tensor(1 + 2j, place=place, stop_gradient=False) x = paddle.to_tensor(1 + 2j, place=place, stop_gradient=False)
self.assertTrue(np.array_equal(x.numpy(), [1 + 2j])) self.assertTrue(np.array_equal(x.numpy(), [1 + 2j]))
......
...@@ -168,7 +168,7 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True): ...@@ -168,7 +168,7 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):
data = data.astype(default_type) data = data.astype(default_type)
if dtype and convert_dtype(dtype) != data.dtype: if dtype and convert_dtype(dtype) != data.dtype:
data = data.astype(dtype) data = data.astype(convert_dtype(dtype))
return paddle.Tensor( return paddle.Tensor(
value=data, value=data,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册