Unverified commit 6f0ae156, authored by PuQing, committed by GitHub

[Numpy]Fix NumpyScaler2Tensor dtype error (#50018)

* fix numpyScaler2Tensor type error

* fix to_tensor docs, test=document_fix
Parent 03619037
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import numpy as np

import paddle

DTYPE_MAP = {
    paddle.bool: np.bool_,
    paddle.int32: np.int32,
    paddle.int64: np.int64,
    paddle.float16: np.float16,
    paddle.float32: np.float32,
    paddle.float64: np.float64,
    paddle.complex64: np.complex64,
}
class NumpyScaler2Tensor(unittest.TestCase):
    def setUp(self):
        self.dtype = np.float32
        self.x_np = np.array([1], dtype=self.dtype)[0]

    def test_dynamic_scaler2tensor(self):
        paddle.disable_static()
        x = paddle.to_tensor(self.x_np)
        self.assertEqual(DTYPE_MAP[x.dtype], self.dtype)
        self.assertEqual(x.numpy(), self.x_np)
        if self.dtype in [
            np.bool_
        ]:  # bool cannot be converted to a 0D-Tensor
            return
        self.assertEqual(len(x.shape), 0)

    def test_static_scaler2tensor(self):
        if self.dtype in [np.float16, np.complex64]:
            return
        paddle.enable_static()
        x = paddle.to_tensor(self.x_np)
        self.assertEqual(DTYPE_MAP[x.dtype], self.dtype)
        if self.dtype in [
            np.bool_,
            np.float64,
        ]:  # bool cannot be converted to a 0D-Tensor, and float64 is not supported in static mode
            return
        self.assertEqual(len(x.shape), 0)
class NumpyScaler2TensorBool(NumpyScaler2Tensor):
    def setUp(self):
        self.dtype = np.bool_
        self.x_np = np.array([1], dtype=self.dtype)[0]


class NumpyScaler2TensorFloat16(NumpyScaler2Tensor):
    def setUp(self):
        self.dtype = np.float16
        self.x_np = np.array([1], dtype=self.dtype)[0]


class NumpyScaler2TensorFloat64(NumpyScaler2Tensor):
    def setUp(self):
        self.dtype = np.float64
        self.x_np = np.array([1], dtype=self.dtype)[0]


class NumpyScaler2TensorInt32(NumpyScaler2Tensor):
    def setUp(self):
        self.dtype = np.int32
        self.x_np = np.array([1], dtype=self.dtype)[0]


class NumpyScaler2TensorInt64(NumpyScaler2Tensor):
    def setUp(self):
        self.dtype = np.int64
        self.x_np = np.array([1], dtype=self.dtype)[0]


class NumpyScaler2TensorComplex64(NumpyScaler2Tensor):
    def setUp(self):
        self.dtype = np.complex64
        self.x_np = np.array([1], dtype=self.dtype)[0]
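For context, a minimal sketch of the behavior these tests pin down (dynamic mode; the output comments assume a build that includes this fix): a NumPy scalar passed to paddle.to_tensor now keeps its own dtype and becomes a 0D-Tensor, instead of falling through to the Python-number path and picking up the default dtype.

import numpy as np
import paddle

paddle.disable_static()

x_np = np.float16(1)        # a NumPy scalar (np.number), not an ndarray
x = paddle.to_tensor(x_np)
print(x.dtype)              # paddle.float16 -- dtype kept, not the float32 default
print(len(x.shape))         # 0 -- a 0D-Tensor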
@@ -533,6 +533,9 @@ def logspace(start, stop, num, base=10.0, dtype=None, name=None):

def _to_tensor_non_static(data, dtype=None, place=None, stop_gradient=True):
    if isinstance(data, np.number):  # Special case for numpy scalars
        data = np.array(data)

    if not isinstance(data, np.ndarray):

        def _handle_dtype(data, dtype):
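As a quick illustration of why the added branch is needed (plain NumPy, nothing Paddle-specific assumed): a NumPy scalar is an np.number but not an np.ndarray, so without the new check it would fall into the generic scalar path below and lose its dtype. np.array() promotes it to a 0-D array while keeping the dtype.

import numpy as np

scalar = np.array([1], dtype=np.int64)[0]  # an np.int64 scalar, as in the tests
print(isinstance(scalar, np.number))       # True
print(isinstance(scalar, np.ndarray))      # False -- would skip the ndarray path

data = np.array(scalar)                    # the added special case
print(data.dtype, data.ndim)               # int64 0 -- dtype kept, zero-dimensional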
@@ -627,6 +630,8 @@ def _to_tensor_static(data, dtype=None, stop_gradient=None):

    if isinstance(data, Variable) and (dtype is None or dtype == data.dtype):
        output = data
    else:
        if isinstance(data, np.number):  # Special case for numpy scalars
            data = np.array(data)

        if not isinstance(data, np.ndarray):
            if np.isscalar(data) and not isinstance(data, str):
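The static-mode branch applies the same promotion; a small sketch mirroring test_static_scaler2tensor above (the value 7 is illustrative, and the output comments assume the fixed build):

import numpy as np
import paddle

paddle.enable_static()

x = paddle.to_tensor(np.int32(7))  # a Variable in the default main program
print(x.dtype)                     # paddle.int32 -- dtype kept in static mode too
print(len(x.shape))                # 0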
@@ -690,6 +695,18 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):

    If the ``data`` is already a Tensor, a copy is performed and a new tensor is returned.
    If you only want to change the stop_gradient property, please call ``Tensor.stop_gradient = stop_gradient`` directly.

    .. code-block:: text

        We use the following dtype conversion rules:

                    Keep dtype
        np.number ───────────► paddle.Tensor
                                 (0D-Tensor)

                       default_dtype
        Python Number ───────────────► paddle.Tensor
                                         (1D-Tensor)

                     Keep dtype
        np.ndarray ───────────► paddle.Tensor

    Args:
        data(scalar|tuple|list|ndarray|Tensor): Initial data for the tensor.
            Can be a scalar, list, tuple, numpy\.ndarray, paddle\.Tensor.
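To make the docstring diagram concrete, a short sketch of the three rules in dynamic mode (printed dtypes and shapes assume the fixed build and the default dtype float32; shapes for the Python-number case follow the diagram above):

import numpy as np
import paddle

paddle.disable_static()

t1 = paddle.to_tensor(np.float64(3.0))  # np.number: keep dtype
print(t1.dtype, t1.shape)               # paddle.float64 [] -- a 0D-Tensor

t2 = paddle.to_tensor(3.0)              # Python number: default dtype
print(t2.dtype, t2.shape)               # paddle.float32 [1] -- a 1D-Tensor

t3 = paddle.to_tensor(np.zeros((2, 3), dtype=np.int64))  # np.ndarray: keep dtype
print(t3.dtype, t3.shape)               # paddle.int64 [2, 3]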