From 245252b86e9878373754db8c66fad35b38cd8e1a Mon Sep 17 00:00:00 2001 From: Zhou Wei <52485244+zhouwei25@users.noreply.github.com> Date: Tue, 30 Mar 2021 15:57:36 +0800 Subject: [PATCH] fix bug when dtype of to_tensor is core.VarType (#31931) --- python/paddle/fluid/tests/unittests/test_var_base.py | 5 +++++ python/paddle/tensor/creation.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index b0c9dda7a30..1fea1935473 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -76,6 +76,11 @@ class TestVarBase(unittest.TestCase): y = x.cuda(blocking=True) self.assertEqual(y.place.__repr__(), "CUDAPlace(0)") + # support 'dtype' is core.VarType + x = paddle.rand((2, 2)) + y = paddle.to_tensor([2, 2], dtype=x.dtype) + self.assertEqual(y.dtype, core.VarDesc.VarType.FP32) + # set_default_dtype take effect on complex x = paddle.to_tensor(1 + 2j, place=place, stop_gradient=False) self.assertTrue(np.array_equal(x.numpy(), [1 + 2j])) diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 056a0226723..69ee2962303 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -168,7 +168,7 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True): data = data.astype(default_type) if dtype and convert_dtype(dtype) != data.dtype: - data = data.astype(dtype) + data = data.astype(convert_dtype(dtype)) return paddle.Tensor( value=data, -- GitLab