未验证 提交 178b1b14 编写于 作者: F feifei-111 提交者: GitHub

[API] fix bug in to_tensor which returns err dtype (#45418)

* DLTP-55757 [Bug] 【Paddle3D】pointpillars在paddle=dev上export model报错

* DLTP-55757 [Bug] 【Paddle3D】pointpillars在paddle=dev上export model报错

* fix default dtype

* coverage

* del no need cases
上级 9a560f7c
......@@ -33,7 +33,7 @@ def case0(x):
def case1(x):
paddle.set_default_dtype("float64")
a = paddle.to_tensor([1.0, 2.0, 3.0], stop_gradient=False)
a = paddle.to_tensor([1, 2, 3], stop_gradient=False, dtype='float32')
return a
......@@ -68,13 +68,27 @@ def case4(x):
place = paddle.CUDAPlace(0)
else:
place = paddle.CPUPlace()
a = paddle.to_tensor([1.0], place=place, dtype="float64")
b = paddle.to_tensor([2], place=place, stop_gradient=False, dtype="int64")
a = paddle.to_tensor([1], place=place)
b = paddle.to_tensor([2.1], place=place, stop_gradient=False, dtype="int64")
c = paddle.to_tensor([a, b, [1]], dtype="float32")
return c
def case5(x):
paddle.set_default_dtype("float64")
a = paddle.to_tensor([1, 2])
return a
def case6(x):
na = numpy.array([1, 2], dtype='int32')
a = paddle.to_tensor(na)
return a
class TestToTensorReturnVal(unittest.TestCase):
def test_to_tensor_badreturn(self):
......@@ -111,6 +125,18 @@ class TestToTensorReturnVal(unittest.TestCase):
self.assertTrue(a.stop_gradient == b.stop_gradient)
self.assertTrue(a.place._equals(b.place))
a = paddle.jit.to_static(case5)(x)
b = case5(x)
self.assertTrue(a.dtype == b.dtype)
self.assertTrue(a.stop_gradient == b.stop_gradient)
self.assertTrue(a.place._equals(b.place))
a = paddle.jit.to_static(case6)(x)
b = case6(x)
self.assertTrue(a.dtype == b.dtype)
self.assertTrue(a.stop_gradient == b.stop_gradient)
self.assertTrue(a.place._equals(b.place))
class TestStatic(unittest.TestCase):
......
......@@ -354,21 +354,29 @@ def _to_tensor_static(data, dtype=None, stop_gradient=None):
if isinstance(data, Variable) and (dtype is None or dtype == data.dtype):
output = data
else:
if not isinstance(data, np.ndarray):
if np.isscalar(data) and not isinstance(data, str):
data = np.array([data])
elif isinstance(data, (list, tuple)):
data = np.array(data)
if isinstance(data,
np.ndarray) and not dtype and data.dtype != 'object':
if data.dtype in ['float16', 'float32', 'float64']:
data = data.astype(paddle.get_default_dtype())
elif data.dtype in ['int32']:
data = data.astype('int64')
if dtype:
target_dtype = dtype
elif hasattr(data, 'dtype'):
elif hasattr(data, 'dtype') and data.dtype != 'object':
target_dtype = data.dtype
else:
target_dtype = paddle.get_default_dtype()
target_dtype = convert_dtype(target_dtype)
if not isinstance(data, np.ndarray):
if np.isscalar(data) and not isinstance(data, str):
data = np.array([data])
elif isinstance(data, (list, tuple)):
data = np.array(data)
if isinstance(data, np.ndarray) and len(data.shape) > 0 and any(
isinstance(x, Variable) for x in data):
if not all(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册