未验证 提交 4617c1b2 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

fix bug of paddle.to_tensor and paddle.moveaxis (#39662)

* fix bug of paddle.to_tensor and paddle.moveaxis

* fix CI
上级 69ab2700
...@@ -423,6 +423,14 @@ class TestMoveAxis(unittest.TestCase): ...@@ -423,6 +423,14 @@ class TestMoveAxis(unittest.TestCase):
self.assertEqual(np.array_equal(out.numpy(), expected), True) self.assertEqual(np.array_equal(out.numpy(), expected), True)
paddle.enable_static() paddle.enable_static()
def test_moveaxis3(self):
paddle.disable_static()
x = paddle.to_tensor(
[[1 + 1j, -1 - 1j], [1 + 1j, -1 - 1j], [1 + 1j, -1 - 1j]])
out = x.moveaxis(0, 1)
self.assertEqual(out.shape, [2, 3])
paddle.enable_static()
def test_error(self): def test_error(self):
x = paddle.randn([2, 3, 4, 5]) x = paddle.randn([2, 3, 4, 5])
# src must have the same number with dst # src must have the same number with dst
......
...@@ -51,6 +51,10 @@ class TestVarBase(unittest.TestCase): ...@@ -51,6 +51,10 @@ class TestVarBase(unittest.TestCase):
np.array_equal(x.numpy(), np.array([1.2], 'float16'))) np.array_equal(x.numpy(), np.array([1.2], 'float16')))
self.assertEqual(x.dtype, core.VarDesc.VarType.FP16) self.assertEqual(x.dtype, core.VarDesc.VarType.FP16)
# set_default_dtype take effect on int
x = paddle.to_tensor(1, place=place)
self.assertTrue(x.dtype, core.VarDesc.VarType.INT64)
# set_default_dtype take effect on float # set_default_dtype take effect on float
x = paddle.to_tensor(1.2, place=place, stop_gradient=False) x = paddle.to_tensor(1.2, place=place, stop_gradient=False)
self.assertTrue( self.assertTrue(
......
...@@ -110,12 +110,6 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True): ...@@ -110,12 +110,6 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):
"'place' must be any of paddle.Place, paddle.CPUPlace, paddle.CUDAPinnedPlace, paddle.CUDAPlace, paddle.NPUPlace, paddle.XPUPlace, paddle.CustomPlace" "'place' must be any of paddle.Place, paddle.CPUPlace, paddle.CUDAPinnedPlace, paddle.CUDAPlace, paddle.NPUPlace, paddle.XPUPlace, paddle.CustomPlace"
) )
#Todo(zhouwei): Support allocate tensor on any other specified card
if isinstance(place, core.CUDAPlace) and isinstance(
_current_expected_place(), core.CUDAPlace) and place._get_device_id(
) != _current_expected_place()._get_device_id():
place = _current_expected_place()
if not isinstance(data, np.ndarray): if not isinstance(data, np.ndarray):
def _handle_dtype(data, dtype): def _handle_dtype(data, dtype):
...@@ -139,7 +133,7 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True): ...@@ -139,7 +133,7 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):
data.stop_gradient = stop_gradient data.stop_gradient = stop_gradient
return data return data
elif isinstance(data, (core.LoDTensor, core.Tensor)): elif isinstance(data, (core.LoDTensor, core.Tensor)):
# Note(zhouwei25): should't expose it to users, just for internal use. # should't expose it to users, just for internal use.
# convert core.Tensor/core.LoDTensor to VarBase first # convert core.Tensor/core.LoDTensor to VarBase first
# Currenly, there is no copy when places are same # Currenly, there is no copy when places are same
data = paddle.Tensor(data) data = paddle.Tensor(data)
...@@ -152,15 +146,20 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True): ...@@ -152,15 +146,20 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):
raise TypeError( raise TypeError(
"Can't constructs a 'paddle.Tensor' with data type {}, data type must be scalar|list|tuple|numpy.ndarray|paddle.Tensor". "Can't constructs a 'paddle.Tensor' with data type {}, data type must be scalar|list|tuple|numpy.ndarray|paddle.Tensor".
format(type(data))) format(type(data)))
if not dtype and data.dtype in [ if not dtype:
'float16', 'float32', 'float64', 'complex64', 'complex128' if data.dtype in [
]: 'float16', 'float32', 'float64', 'complex64', 'complex128'
default_type = paddle.get_default_dtype() ]:
if np.iscomplexobj(data): default_type = paddle.get_default_dtype()
default_type = 'complex64' if default_type in [ if np.iscomplexobj(data):
'float16', 'float32' default_type = 'complex64' if default_type in [
] else 'complex128' 'float16', 'float32'
data = data.astype(default_type) ] else 'complex128'
data = data.astype(default_type)
# Windows default type is 'int32', while Linux/Mac is 'int64'. Unify they.
if data.dtype in ['int32']:
default_type = "int64"
data = data.astype(default_type)
if dtype and convert_dtype(dtype) != data.dtype: if dtype and convert_dtype(dtype) != data.dtype:
data = data.astype(convert_dtype(dtype)) data = data.astype(convert_dtype(dtype))
......
...@@ -2737,9 +2737,10 @@ def moveaxis(x, source, destination, name=None): ...@@ -2737,9 +2737,10 @@ def moveaxis(x, source, destination, name=None):
out, _ = _C_ops.transpose2(x, 'axis', perm) out, _ = _C_ops.transpose2(x, 'axis', perm)
return out return out
check_variable_and_dtype( check_variable_and_dtype(x, 'x', [
x, 'x', ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], 'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'complex64',
'moveaxis') 'complex128'
], 'moveaxis')
helper = LayerHelper('moveaxis', **locals()) helper = LayerHelper('moveaxis', **locals())
out = helper.create_variable_for_type_inference(x.dtype) out = helper.create_variable_for_type_inference(x.dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册