未验证 提交 94e8fc78 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[frl_train_eval] add bfloat16 dtype support of to_tensor,due to numpy not support bfloat16 (#53153)

上级 c47853f6
......@@ -81,7 +81,8 @@ PD_REGISTER_KERNEL(less_than,
int64_t,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(less_equal,
CPU,
ALL_LAYOUT,
......@@ -92,7 +93,8 @@ PD_REGISTER_KERNEL(less_equal,
int64_t,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(greater_than,
CPU,
ALL_LAYOUT,
......@@ -103,7 +105,8 @@ PD_REGISTER_KERNEL(greater_than,
int64_t,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(greater_equal,
CPU,
ALL_LAYOUT,
......@@ -114,7 +117,8 @@ PD_REGISTER_KERNEL(greater_equal,
int64_t,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(equal,
CPU,
ALL_LAYOUT,
......@@ -125,7 +129,8 @@ PD_REGISTER_KERNEL(equal,
int64_t,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(not_equal,
CPU,
ALL_LAYOUT,
......@@ -136,7 +141,8 @@ PD_REGISTER_KERNEL(not_equal,
int64_t,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(equal_all,
CPU,
......
......@@ -21,6 +21,7 @@ import six
from six.moves import zip, range, xrange
import multiprocessing
import warnings
import struct
from .framework import Variable, default_main_program, _current_expected_place, _non_static_mode, _in_eager_without_dygraph_check
from .framework import _cpu_num, _cuda_ids
......@@ -43,6 +44,27 @@ _PADDLE_DTYPE_2_NUMPY_DTYPE = {
}
def copy_bits_from_float_to_uint16(f):
return struct.unpack('<I', struct.pack('<f', f))[0] >> 16
def convert_float_to_uint16(data, data_format="NCHW"):
if data.size == 0:
return data.view(np.uint16)
if data_format == "NHWC":
data = np.transpose(data, [0, 3, 1, 2])
new_data = []
for x in np.nditer(data):
new_data.append(np.uint16(copy_bits_from_float_to_uint16(x)))
new_data = np.reshape(new_data, data.shape).view(np.uint16)
if data_format == "NHWC":
new_data = np.transpose(new_output, [0, 2, 3, 1])
return new_data
def convert_dtype(dtype):
if isinstance(dtype, core.VarDesc.VarType):
if dtype in _PADDLE_DTYPE_2_NUMPY_DTYPE:
......@@ -68,7 +90,9 @@ def convert_dtype(dtype):
# however, jointly supporting python2 and python3, (as well as python4 maybe)
# may still be a long-lasting problem.
return str(dtype)
# NOTE(zhangbo): Now numpy does not support bfloat, and paddle use uint16 to represent bfloat16, and there binaries are consistent.
# NOTE(zhangbo): Now numpy does not support bfloat, so use numpy.uint16 to represent paddle.bfloat16, there binaries are consistent.
# If cast ndarray to uint16 and trans to tensor, should not ndarray.astype('uint16') directly
# should use function 'convert_float_to_uint16' above, otherwise bits is wrong
if dtype in ['bfloat16']:
return 'uint16'
......
......@@ -244,6 +244,26 @@ class TestVarBase(unittest.TestCase):
np.testing.assert_array_equal(x.numpy(), numpy_array)
self.assertEqual(x.type, core.VarDesc.VarType.LOD_TENSOR)
# test dtype bfloat16
x = paddle.to_tensor(-1e6, dtype=paddle.bfloat16)
self.assertEqual(x.dtype, core.VarDesc.VarType.BF16)
self.assertTrue(x == -999424.0)
x = paddle.to_tensor([-1e6, -1e6, -1e6], dtype='bfloat16')
self.assertEqual(x.dtype, core.VarDesc.VarType.BF16)
self.assertTrue(x[0] == -999424.0)
self.assertTrue(x[1] == -999424.0)
self.assertTrue(x[2] == -999424.0)
x = paddle.to_tensor(
-1e6, dtype=paddle.bfloat16, stop_gradient=False
)
self.assertEqual(x.dtype, core.VarDesc.VarType.BF16)
self.assertTrue(x == -999424.0)
y = x * x
y.backward()
self.assertTrue(x.grad == -999424.0 * 2)
with self.assertRaises(ValueError):
paddle.randn([3, 2, 2]).item()
with self.assertRaises(ValueError):
......
......@@ -29,6 +29,7 @@ from ..fluid.data_feeder import (
check_type,
check_dtype,
convert_dtype,
convert_float_to_uint16,
)
from ..framework import (
convert_np_dtype_to_dtype_,
......@@ -399,7 +400,11 @@ def _to_tensor_non_static(data, dtype=None, place=None, stop_gradient=True):
data = data.astype(default_type)
if dtype and convert_dtype(dtype) != data.dtype:
data = data.astype(convert_dtype(dtype))
if convert_dtype(dtype) in ['uint16']:
# should not ndarray.astype('uint16') directly, data bits is wrong
data = convert_float_to_uint16(data.astype('float32'))
else:
data = data.astype(convert_dtype(dtype))
if _in_eager_without_dygraph_check() and isinstance(data, np.ndarray):
return core.eager.Tensor(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册