未验证 提交 2aedd9db 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[BUG] fix paddle.to_tensor/Tensor.item/Tensor.numpy BF16 bug (#53567)

上级 0a59825e
......@@ -1079,13 +1079,11 @@ static PyObject* tensor__getitem_from_offset(TensorObject* self,
T b = paddle::pybind::TensorGetElement<T>(tensor, offset); \
Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank]; \
Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank]; \
py_dims[0] = 1; \
py_strides[0] = 1; \
auto& api = pybind11::detail::npy_api::get(); \
PyObject* array = api.PyArray_NewFromDescr_( \
api.PyArray_Type_, \
api.PyArray_DescrFromType_(numpy_dtype), \
1, \
0, \
py_dims, \
py_strides, \
nullptr, \
......
......@@ -46,10 +46,6 @@ _PADDLE_DTYPE_2_NUMPY_DTYPE = {
}
def copy_bits_from_float_to_uint16(f):
return struct.unpack('<I', struct.pack('<f', f))[0] >> 16
def convert_float_to_uint16(data, data_format="NCHW"):
if data.size == 0:
return data.view(np.uint16)
......@@ -57,16 +53,25 @@ def convert_float_to_uint16(data, data_format="NCHW"):
if data_format == "NHWC":
data = np.transpose(data, [0, 3, 1, 2])
new_data = []
for x in np.nditer(data):
new_data.append(np.uint16(copy_bits_from_float_to_uint16(x)))
new_data = np.reshape(new_data, data.shape).view(np.uint16)
new_data = np.vectorize(
lambda x: struct.unpack('<I', struct.pack('<f', x))[0] >> 16,
otypes=[np.uint16],
)(data.flat)
new_data = np.reshape(new_data, data.shape)
if data_format == "NHWC":
new_data = np.transpose(new_output, [0, 2, 3, 1])
new_data = np.transpose(new_data, [0, 2, 3, 1])
return new_data
def convert_uint16_to_float(data):
new_data = np.vectorize(
lambda x: struct.unpack('<f', struct.pack('<I', x << 16))[0],
otypes=[np.float32],
)(data.flat)
return np.reshape(new_data, data.shape)
def convert_dtype(dtype):
if isinstance(dtype, core.VarDesc.VarType):
if dtype in _PADDLE_DTYPE_2_NUMPY_DTYPE:
......
......@@ -33,7 +33,10 @@ from ..framework import (
)
from .base import switch_to_static_graph
from .math_op_patch import monkey_patch_math_tensor
from paddle.fluid.data_feeder import convert_dtype, _PADDLE_DTYPE_2_NUMPY_DTYPE
from paddle.fluid.data_feeder import (
convert_uint16_to_float,
_PADDLE_DTYPE_2_NUMPY_DTYPE,
)
import paddle.utils.deprecated as deprecated
import paddle.profiler as profiler
from paddle.profiler.utils import in_profiler_mode
......@@ -614,7 +617,10 @@ def monkey_patch_tensor():
print(x.item(0, 2)) #3.3
"""
return self._getitem_from_offset(*args).item()
scalar = self._getitem_from_offset(*args)
if scalar.dtype == np.uint16:
return convert_uint16_to_float(scalar).item()
return scalar.item()
@property
def inplace_version(self):
......
......@@ -246,10 +246,12 @@ class TestVarBase(unittest.TestCase):
np.testing.assert_array_equal(x.numpy(), numpy_array)
self.assertEqual(x.type, core.VarDesc.VarType.LOD_TENSOR)
# test dtype bfloat16
# test dtype=bfloat16
x = paddle.to_tensor(-1e6, dtype=paddle.bfloat16)
self.assertEqual(x.dtype, core.VarDesc.VarType.BF16)
self.assertTrue(x == -999424.0)
self.assertTrue(x.item() == -999424.0)
self.assertTrue(isinstance(x.item(), float))
x = paddle.to_tensor([-1e6, -1e6, -1e6], dtype='bfloat16')
self.assertEqual(x.dtype, core.VarDesc.VarType.BF16)
......@@ -266,6 +268,28 @@ class TestVarBase(unittest.TestCase):
y.backward()
self.assertTrue(x.grad == -999424.0 * 2)
# test default_type=bfloat16
paddle.set_default_dtype('bfloat16')
x = paddle.to_tensor(-1e6)
self.assertEqual(x.dtype, core.VarDesc.VarType.BF16)
self.assertTrue(x == -999424.0)
self.assertTrue(x.item() == -999424.0)
self.assertTrue(isinstance(x.item(), float))
x = paddle.to_tensor([-1e6, -1e6, -1e6])
self.assertEqual(x.dtype, core.VarDesc.VarType.BF16)
self.assertTrue(x[0] == -999424.0)
self.assertTrue(x[1] == -999424.0)
self.assertTrue(x[2] == -999424.0)
x = paddle.to_tensor(-1e6, stop_gradient=False)
self.assertEqual(x.dtype, core.VarDesc.VarType.BF16)
self.assertTrue(x == -999424.0)
y = x * x
y.backward()
self.assertTrue(x.grad == -999424.0 * 2)
paddle.set_default_dtype('float32')
with self.assertRaises(ValueError):
paddle.randn([3, 2, 2]).item()
with self.assertRaises(ValueError):
......
......@@ -40,7 +40,7 @@ def set_default_dtype(d):
"""
if isinstance(d, type):
# This branch is for NumPy scalar types
# This branch is for np.dtype
if d in [np.float16, np.float32, np.float64]:
d = d.__name__
else:
......@@ -49,7 +49,7 @@ def set_default_dtype(d):
", but received %s" % d.__name__
)
else:
# This branch is for np.dtype and str
# This branch is for str
if d in ['float16', 'float32', 'float64', 'bfloat16']:
# NOTE(SigureMo): Since the np.dtype object is not an instance of
# type, so it will not be handled by the previous branch. We need
......
......@@ -83,8 +83,6 @@ class ProgressBar:
if k == "loss":
if isinstance(val, list):
scalar_val = val[0]
elif isinstance(val, np.ndarray):
scalar_val = val.item()
else:
scalar_val = val
if isinstance(scalar_val, np.uint16):
......
......@@ -542,18 +542,28 @@ def logspace(start, stop, num, base=10.0, dtype=None, name=None):
def _to_tensor_non_static(data, dtype=None, place=None, stop_gradient=True):
def _handle_tensor_dtype(tensor, dtype):
if dtype:
if convert_dtype(dtype) != convert_dtype(tensor.dtype):
return tensor.astype(convert_dtype(dtype))
return tensor
def _handle_np_dtype(ndarray, dtype):
if dtype:
if convert_dtype(dtype) != convert_dtype(ndarray.dtype):
# should not ndarray.astype('uint16') directly, data bits is wrong
if convert_dtype(dtype) in ['uint16']:
return convert_float_to_uint16(ndarray.astype('float32'))
else:
return ndarray.astype(convert_dtype(dtype))
return ndarray
if isinstance(data, np.number): # Special case for numpy scalars
data = np.array(data)
if not isinstance(data, np.ndarray):
def _handle_dtype(data, dtype):
if dtype:
if convert_dtype(dtype) != convert_dtype(data.dtype):
return data.astype(convert_dtype(dtype))
return data
if np.isscalar(data) and not isinstance(data, str):
data = np.array(data)
elif isinstance(data, (list, tuple)):
......@@ -565,12 +575,12 @@ def _to_tensor_non_static(data, dtype=None, place=None, stop_gradient=True):
)
elif isinstance(data, paddle.Tensor) and not in_dygraph_mode():
data = data._copy_to(place, False)
data = _handle_dtype(data, dtype)
data = _handle_tensor_dtype(data, dtype)
data.stop_gradient = stop_gradient
return data
elif isinstance(data, core.eager.Tensor) and in_dygraph_mode():
data = data._copy_to(place, False)
data = _handle_dtype(data, dtype)
data = _handle_tensor_dtype(data, dtype)
data.stop_gradient = stop_gradient
return data
elif isinstance(data, (core.LoDTensor, core.Tensor)):
......@@ -583,7 +593,7 @@ def _to_tensor_non_static(data, dtype=None, place=None, stop_gradient=True):
data = paddle.Tensor(data)
if not data.place._equals(place):
data = data._copy_to(place, False)
data = _handle_dtype(data, dtype)
data = _handle_tensor_dtype(data, dtype)
data.stop_gradient = stop_gradient
return data
else:
......@@ -607,18 +617,13 @@ def _to_tensor_non_static(data, dtype=None, place=None, stop_gradient=True):
if default_type in ['float16', 'float32']
else 'complex128'
)
data = data.astype(default_type)
data = _handle_np_dtype(data, default_type)
# Windows default type is 'int32', while Linux/Mac is 'int64'. Unify they.
if data.dtype in ['int32']:
default_type = "int64"
data = data.astype(default_type)
data = data.astype("int64")
if dtype and convert_dtype(dtype) != data.dtype:
if convert_dtype(dtype) in ['uint16']:
# should not ndarray.astype('uint16') directly, data bits is wrong
data = convert_float_to_uint16(data.astype('float32'))
else:
data = data.astype(convert_dtype(dtype))
if dtype:
data = _handle_np_dtype(data, dtype)
if _in_eager_without_dygraph_check() and isinstance(data, np.ndarray):
return core.eager.Tensor(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册