From 94e8fc785dc5776de0b9391b780d03a836665542 Mon Sep 17 00:00:00 2001 From: Zhou Wei <1183042833@qq.com> Date: Fri, 21 Apr 2023 22:27:54 +0800 Subject: [PATCH] [frl_train_eval] add bfloat16 dtype support of to_tensor,due to numpy not support bfloat16 (#53153) --- paddle/phi/kernels/cpu/compare_kernel.cc | 18 ++++++++----- python/paddle/fluid/data_feeder.py | 26 ++++++++++++++++++- .../fluid/tests/unittests/test_var_base.py | 20 ++++++++++++++ python/paddle/tensor/creation.py | 7 ++++- 4 files changed, 63 insertions(+), 8 deletions(-) diff --git a/paddle/phi/kernels/cpu/compare_kernel.cc b/paddle/phi/kernels/cpu/compare_kernel.cc index ae6c3fd5cb0..a4f33496087 100644 --- a/paddle/phi/kernels/cpu/compare_kernel.cc +++ b/paddle/phi/kernels/cpu/compare_kernel.cc @@ -81,7 +81,8 @@ PD_REGISTER_KERNEL(less_than, int64_t, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(less_equal, CPU, ALL_LAYOUT, @@ -92,7 +93,8 @@ PD_REGISTER_KERNEL(less_equal, int64_t, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(greater_than, CPU, ALL_LAYOUT, @@ -103,7 +105,8 @@ PD_REGISTER_KERNEL(greater_than, int64_t, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(greater_equal, CPU, ALL_LAYOUT, @@ -114,7 +117,8 @@ PD_REGISTER_KERNEL(greater_equal, int64_t, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(equal, CPU, ALL_LAYOUT, @@ -125,7 +129,8 @@ PD_REGISTER_KERNEL(equal, int64_t, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(not_equal, CPU, ALL_LAYOUT, @@ -136,7 +141,8 @@ PD_REGISTER_KERNEL(not_equal, int64_t, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(equal_all, CPU, diff --git a/python/paddle/fluid/data_feeder.py b/python/paddle/fluid/data_feeder.py index 876d4772462..1a7ffaf26bf 100644 --- a/python/paddle/fluid/data_feeder.py +++ b/python/paddle/fluid/data_feeder.py @@ -21,6 +21,7 @@ import six from six.moves import zip, range, xrange import multiprocessing import warnings +import struct from .framework import Variable, default_main_program, _current_expected_place, _non_static_mode, _in_eager_without_dygraph_check from .framework import _cpu_num, _cuda_ids @@ -43,6 +44,27 @@ _PADDLE_DTYPE_2_NUMPY_DTYPE = { } +def copy_bits_from_float_to_uint16(f): + return struct.unpack('> 16 + + +def convert_float_to_uint16(data, data_format="NCHW"): + if data.size == 0: + return data.view(np.uint16) + + if data_format == "NHWC": + data = np.transpose(data, [0, 3, 1, 2]) + + new_data = [] + for x in np.nditer(data): + new_data.append(np.uint16(copy_bits_from_float_to_uint16(x))) + new_data = np.reshape(new_data, data.shape).view(np.uint16) + + if data_format == "NHWC": + new_data = np.transpose(new_output, [0, 2, 3, 1]) + return new_data + + def convert_dtype(dtype): if isinstance(dtype, core.VarDesc.VarType): if dtype in _PADDLE_DTYPE_2_NUMPY_DTYPE: @@ -68,7 +90,9 @@ def convert_dtype(dtype): # however, jointly supporting python2 and python3, (as well as python4 maybe) # may still be a long-lasting problem. return str(dtype) - # NOTE(zhangbo): Now numpy does not support bfloat, and paddle use uint16 to represent bfloat16, and there binaries are consistent. + # NOTE(zhangbo): Now numpy does not support bfloat, so use numpy.uint16 to represent paddle.bfloat16, there binaries are consistent. + # If cast ndarray to uint16 and trans to tensor, should not ndarray.astype('uint16') directly + # should use function 'convert_float_to_uint16' above, otherwise bits is wrong if dtype in ['bfloat16']: return 'uint16' diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index 2d5778b7b20..f7f0ea30f7b 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -244,6 +244,26 @@ class TestVarBase(unittest.TestCase): np.testing.assert_array_equal(x.numpy(), numpy_array) self.assertEqual(x.type, core.VarDesc.VarType.LOD_TENSOR) + # test dtype bfloat16 + x = paddle.to_tensor(-1e6, dtype=paddle.bfloat16) + self.assertEqual(x.dtype, core.VarDesc.VarType.BF16) + self.assertTrue(x == -999424.0) + + x = paddle.to_tensor([-1e6, -1e6, -1e6], dtype='bfloat16') + self.assertEqual(x.dtype, core.VarDesc.VarType.BF16) + self.assertTrue(x[0] == -999424.0) + self.assertTrue(x[1] == -999424.0) + self.assertTrue(x[2] == -999424.0) + + x = paddle.to_tensor( + -1e6, dtype=paddle.bfloat16, stop_gradient=False + ) + self.assertEqual(x.dtype, core.VarDesc.VarType.BF16) + self.assertTrue(x == -999424.0) + y = x * x + y.backward() + self.assertTrue(x.grad == -999424.0 * 2) + with self.assertRaises(ValueError): paddle.randn([3, 2, 2]).item() with self.assertRaises(ValueError): diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 82a016ce64d..6e4337481d7 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -29,6 +29,7 @@ from ..fluid.data_feeder import ( check_type, check_dtype, convert_dtype, + convert_float_to_uint16, ) from ..framework import ( convert_np_dtype_to_dtype_, @@ -399,7 +400,11 @@ def _to_tensor_non_static(data, dtype=None, place=None, stop_gradient=True): data = data.astype(default_type) if dtype and convert_dtype(dtype) != data.dtype: - data = data.astype(convert_dtype(dtype)) + if convert_dtype(dtype) in ['uint16']: + # should not ndarray.astype('uint16') directly, data bits is wrong + data = convert_float_to_uint16(data.astype('float32')) + else: + data = data.astype(convert_dtype(dtype)) if _in_eager_without_dygraph_check() and isinstance(data, np.ndarray): return core.eager.Tensor( -- GitLab