未验证 提交 116fcada 编写于 作者: F feifei-111 提交者: GitHub

【BugFix】fix err of api `to_tensor`, which caused by numpy version update (#53534)

* fix

* update code

* pre-commit

* remove scale check (0-D tensor is usable)

* fix data dtype err

* fix numpy default dtype diff

* fix data dtype

* fix data dtype

* update

* fix coverage
上级 f74237cd
......@@ -646,8 +646,10 @@ def _to_tensor_non_static(data, dtype=None, place=None, stop_gradient=True):
def _to_tensor_static(data, dtype=None, stop_gradient=None):
if isinstance(data, Variable) and (dtype is None or dtype == data.dtype):
if isinstance(data, Variable):
output = data
if dtype is not None and dtype != data.dtype:
output = paddle.cast(output, dtype)
else:
if isinstance(data, np.number): # Special case for numpy scalars
data = np.array(data)
......@@ -656,17 +658,37 @@ def _to_tensor_static(data, dtype=None, stop_gradient=None):
if np.isscalar(data) and not isinstance(data, str):
data = np.array(data)
elif isinstance(data, (list, tuple)):
try:
'''
In numpy version >= 1.24.0, case like:
np.array([Variable, 1, 2])
is not supported, it will raise error (numpy returns an numpy array with dtype='object' in version <= 1.23.5)
Thus, process nested structure in except block
'''
data = np.array(data)
if (
isinstance(data, np.ndarray)
and not dtype
and data.dtype != 'object'
):
# for numpy version <= 1.23.5
if data.dtype == 'object':
raise RuntimeError("Numpy get dtype `object`.")
except:
to_stack_list = [None] * len(data)
for idx, d in enumerate(data):
to_stack_list[idx] = _to_tensor_static(
d, dtype, stop_gradient
)
data = paddle.stack(to_stack_list)
data = paddle.squeeze(data, -1)
else:
raise RuntimeError(
f"Do not support transform type `{type(data)}` to tensor"
)
# fix numpy default dtype
if data.dtype in ['float16', 'float32', 'float64']:
data = data.astype(paddle.get_default_dtype())
elif data.dtype in ['int32']:
data = data.astype('int64')
if dtype:
target_dtype = dtype
......@@ -674,24 +696,10 @@ def _to_tensor_static(data, dtype=None, stop_gradient=None):
target_dtype = data.dtype
else:
target_dtype = paddle.get_default_dtype()
target_dtype = convert_dtype(target_dtype)
if (
isinstance(data, np.ndarray)
and len(data.shape) > 0
and any(isinstance(x, Variable) for x in data)
):
to_stack_list = [None] * data.shape[0]
for idx, d in enumerate(data):
to_stack_list[idx] = _to_tensor_static(d, dtype, stop_gradient)
data = paddle.stack(to_stack_list)
data = paddle.squeeze(data, -1)
if not isinstance(data, Variable):
output = assign(data)
else:
output = data
if convert_dtype(output.dtype) != target_dtype:
output = paddle.cast(output, target_dtype)
......
......@@ -90,6 +90,11 @@ def case7(x):
return a
def case8(x):
a = paddle.to_tensor({1: 1})
return a
class TestToTensorReturnVal(unittest.TestCase):
def test_to_tensor_badreturn(self):
paddle.disable_static()
......@@ -143,6 +148,17 @@ class TestToTensorReturnVal(unittest.TestCase):
self.assertTrue(a.stop_gradient == b.stop_gradient)
self.assertTrue(a.place._equals(b.place))
def test_to_tensor_err_log(self):
paddle.disable_static()
x = paddle.to_tensor([3])
try:
a = paddle.jit.to_static(case8)(x)
except Exception as e:
self.assertTrue(
"Do not support transform type `<class 'dict'>` to tensor"
in str(e)
)
class TestStatic(unittest.TestCase):
def test_static(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册