未验证 提交 c13cf739 编写于 作者: Z Zhou Wei 提交者: GitHub

set_default_type only take effect on python floats or complex (#26939) (#26940)

* set_default_type only take effect on python floats or complex

* fix doc
上级 89af2088
......@@ -33,16 +33,28 @@ class TestVarBase(unittest.TestCase):
def _test_place(place):
with fluid.dygraph.guard():
paddle.set_default_dtype('float32')
# set_default_dtype should not take effect on int
x = paddle.to_tensor(1, place=place, stop_gradient=False)
self.assertTrue(np.array_equal(x.numpy(), [1]))
self.assertNotEqual(x.dtype, core.VarDesc.VarType.FP32)
# set_default_dtype should not take effect on numpy
x = paddle.to_tensor(
np.array([1.2]).astype('float16'),
place=place,
stop_gradient=False)
self.assertTrue(
np.array_equal(x.numpy(), np.array([1.2], 'float16')))
self.assertEqual(x.dtype, core.VarDesc.VarType.FP16)
# set_default_dtype take effect on float
x = paddle.to_tensor(1.2, place=place, stop_gradient=False)
self.assertTrue(
np.array_equal(x.numpy(), np.array([1.2]).astype(
'float32')))
self.assertEqual(x.dtype, core.VarDesc.VarType.FP32)
# set_default_dtype take effect on complex
x = paddle.to_tensor(1 + 2j, place=place, stop_gradient=False)
self.assertTrue(np.array_equal(x.numpy(), [1 + 2j]))
self.assertEqual(x.dtype, 'complex64')
......
......@@ -73,8 +73,8 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):
Can be a scalar, list, tuple, numpy\.ndarray, paddle\.Tensor, paddle\.ComplexTensor.
dtype(str|np.dtype, optional): The desired data type of returned tensor. Can be 'bool' , 'float16' ,
'float32' , 'float64' , 'int8' , 'int16' , 'int32' , 'int64' , 'uint8'. And
'complex64' , 'complex128' only for ComplexTensor. Default: None, for float point number,
get type from ``get_default_type``, for other type, infers from ``data`` .
'complex64' , 'complex128' only for ComplexTensor. Default: None, infers dtype from ``data``
except for python float number which gets dtype from ``get_default_type`` .
place(CPUPlace|CUDAPinnedPlace|CUDAPlace, optional): The place to allocate Tensor. Can be
CPUPlace, CUDAPinnedPlace, CUDAPlace. Default: None, means global place.
stop_gradient(bool, optional): Whether to block the gradient propagation of Autograd. Default: True.
......@@ -188,13 +188,21 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):
raise TypeError(
"Can't constructs a 'paddle.Tensor' with data type {}, data type must be scalar|list|tuple|numpy.ndarray|paddle.Tensor|paddle.ComplexTensor".
format(type(data)))
if not dtype and data.dtype in [
'float16', 'float32', 'float64', 'complex64', 'complex128'
]:
default_type = paddle.get_default_dtype()
if np.iscomplexobj(data):
default_type = 'complex64' if default_type in [
'float16', 'float32'
] else 'complex128'
data = data.astype(default_type)
if dtype and convert_dtype(dtype) != data.dtype:
data = data.astype(dtype)
if not np.iscomplexobj(data):
if dtype:
dtype = convert_dtype(dtype)
elif data.dtype in ['float16', 'float32', 'float64']:
dtype = paddle.framework.get_default_dtype()
if dtype and dtype != data.dtype:
if dtype and convert_dtype(dtype) != data.dtype:
data = data.astype(dtype)
return paddle.Tensor(
value=data,
......@@ -203,14 +211,6 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):
zero_copy=True,
stop_gradient=stop_gradient)
else:
if dtype:
dtype = convert_dtype(dtype)
else:
dtype = paddle.framework.get_default_dtype()
dtype = 'complex64' if dtype in ['float16', 'float32'
] else 'complex128'
if dtype != data.dtype:
data = data.astype(dtype)
name = unique_name.generate('generated_tensor')
real_tensor = paddle.Tensor(
value=data.real,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册