From 116fcadaf19b7cbbf6443a6f0788067381d2dbf4 Mon Sep 17 00:00:00 2001 From: feifei-111 <2364819892@qq.com> Date: Mon, 8 May 2023 19:09:45 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90BugFix=E3=80=91fix=20err=20of=20api=20?= =?UTF-8?q?`to=5Ftensor`,=20which=20caused=20by=20numpy=20version=20update?= =?UTF-8?q?=20(#53534)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix * update code * pre-commit * remove scale check (0-D tensor is usable) * fix data dtype err * fix numpy default dtype diff * fix data dtype * fix data dtype * update * fix coverage --- python/paddle/tensor/creation.py | 62 +++++++++++++----------- test/dygraph_to_static/test_to_tensor.py | 16 ++++++ 2 files changed, 51 insertions(+), 27 deletions(-) diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index ee546ffe17e..6fe3f97cd65 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -646,8 +646,10 @@ def _to_tensor_non_static(data, dtype=None, place=None, stop_gradient=True): def _to_tensor_static(data, dtype=None, stop_gradient=None): - if isinstance(data, Variable) and (dtype is None or dtype == data.dtype): + if isinstance(data, Variable): output = data + if dtype is not None and dtype != data.dtype: + output = paddle.cast(output, dtype) else: if isinstance(data, np.number): # Special case for numpy scalars data = np.array(data) @@ -656,17 +658,37 @@ def _to_tensor_static(data, dtype=None, stop_gradient=None): if np.isscalar(data) and not isinstance(data, str): data = np.array(data) elif isinstance(data, (list, tuple)): - data = np.array(data) + try: + ''' + In numpy version >= 1.24.0, case like: + np.array([Variable, 1, 2]) + is not supported, it will raise error (numpy returns an numpy array with dtype='object' in version <= 1.23.5) + + Thus, process nested structure in except block + ''' + data = np.array(data) + + # for numpy version <= 1.23.5 + if data.dtype == 'object': + raise RuntimeError("Numpy get dtype `object`.") + + except: + to_stack_list = [None] * len(data) + for idx, d in enumerate(data): + to_stack_list[idx] = _to_tensor_static( + d, dtype, stop_gradient + ) + data = paddle.stack(to_stack_list) + data = paddle.squeeze(data, -1) - if ( - isinstance(data, np.ndarray) - and not dtype - and data.dtype != 'object' - ): - if data.dtype in ['float16', 'float32', 'float64']: - data = data.astype(paddle.get_default_dtype()) - elif data.dtype in ['int32']: - data = data.astype('int64') + else: + raise RuntimeError( + f"Do not support transform type `{type(data)}` to tensor" + ) + + # fix numpy default dtype + if data.dtype in ['float16', 'float32', 'float64']: + data = data.astype(paddle.get_default_dtype()) if dtype: target_dtype = dtype @@ -674,24 +696,10 @@ def _to_tensor_static(data, dtype=None, stop_gradient=None): target_dtype = data.dtype else: target_dtype = paddle.get_default_dtype() - target_dtype = convert_dtype(target_dtype) - if ( - isinstance(data, np.ndarray) - and len(data.shape) > 0 - and any(isinstance(x, Variable) for x in data) - ): - to_stack_list = [None] * data.shape[0] - for idx, d in enumerate(data): - to_stack_list[idx] = _to_tensor_static(d, dtype, stop_gradient) - data = paddle.stack(to_stack_list) - data = paddle.squeeze(data, -1) - - if not isinstance(data, Variable): - output = assign(data) - else: - output = data + output = assign(data) + if convert_dtype(output.dtype) != target_dtype: output = paddle.cast(output, target_dtype) diff --git a/test/dygraph_to_static/test_to_tensor.py b/test/dygraph_to_static/test_to_tensor.py index 05cd5ec78f2..e96c5247a0d 100644 --- a/test/dygraph_to_static/test_to_tensor.py +++ b/test/dygraph_to_static/test_to_tensor.py @@ -90,6 +90,11 @@ def case7(x): return a +def case8(x): + a = paddle.to_tensor({1: 1}) + return a + + class TestToTensorReturnVal(unittest.TestCase): def test_to_tensor_badreturn(self): paddle.disable_static() @@ -143,6 +148,17 @@ class TestToTensorReturnVal(unittest.TestCase): self.assertTrue(a.stop_gradient == b.stop_gradient) self.assertTrue(a.place._equals(b.place)) + def test_to_tensor_err_log(self): + paddle.disable_static() + x = paddle.to_tensor([3]) + try: + a = paddle.jit.to_static(case8)(x) + except Exception as e: + self.assertTrue( + "Do not support transform type `` to tensor" + in str(e) + ) + class TestStatic(unittest.TestCase): def test_static(self): -- GitLab