Unverified · Commit f57b21e6, authored by Leo Chen, committed by GitHub

[bf16] support printing bf16 tensor (#39375)

Parent: eacfc1eb
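For context, a minimal usage sketch (not part of the commit) of what the change enables: printing a tensor after casting it to bfloat16. It mirrors the new unit test below and assumes dygraph mode on the CPU place; only `paddle` APIs that appear in the diff are used.

import paddle
import paddle.fluid.core as core

# Build a small float32 tensor and cast it to bfloat16 (same setup as the new test).
paddle.disable_static(paddle.CPUPlace())
x = paddle.to_tensor([[1.5, 1.0], [0.0, 0.0]])
x = paddle.cast(x, dtype=core.VarDesc.VarType.BF16)

# After this commit, str()/print() report dtype=bfloat16 and render the values
# through a float32 copy instead of failing on the bf16 storage type.
paddle.set_printoptions(precision=4)
print(x)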
@@ -1094,6 +1094,20 @@ class TestVarBase(unittest.TestCase):
         self.assertEqual(a_str, expected)
         paddle.enable_static()
 
+    def test_tensor_str_bf16(self):
+        paddle.disable_static(paddle.CPUPlace())
+        a = paddle.to_tensor([[1.5, 1.0], [0, 0]])
+        a = paddle.cast(a, dtype=core.VarDesc.VarType.BF16)
+        paddle.set_printoptions(precision=4)
+        a_str = str(a)
+
+        expected = '''Tensor(shape=[2, 2], dtype=bfloat16, place=Place(cpu), stop_gradient=True,
+       [[1.5000, 1.    ],
+        [0.    , 0.    ]])'''
+
+        self.assertEqual(a_str, expected)
+        paddle.enable_static()
+
     def test_print_tensor_dtype(self):
         paddle.disable_static(paddle.CPUPlace())
         a = paddle.rand([1])
...
@@ -223,12 +223,18 @@ def _format_tensor(var, summary, indent=0, max_width=0, signed=False):
 def to_string(var, prefix='Tensor'):
     indent = len(prefix) + 1
 
+    dtype = convert_dtype(var.dtype)
+    if var.dtype == core.VarDesc.VarType.BF16:
+        dtype = 'bfloat16'
+
     _template = "{prefix}(shape={shape}, dtype={dtype}, place={place}, stop_gradient={stop_gradient},\n{indent}{data})"
 
     tensor = var.value().get_tensor()
     if not tensor._is_initialized():
         return "Tensor(Not initialized)"
 
+    if var.dtype == core.VarDesc.VarType.BF16:
+        var = var.astype('float32')
     np_var = var.numpy()
 
     if len(var.shape) == 0:
@@ -250,7 +256,7 @@ def to_string(var, prefix='Tensor'):
     return _template.format(
         prefix=prefix,
         shape=var.shape,
-        dtype=convert_dtype(var.dtype),
+        dtype=dtype,
         place=var._place_str,
         stop_gradient=var.stop_gradient,
         indent=' ' * indent,
...
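The core idea of the change above is to keep the reported dtype string ("bfloat16") separate from the array handed to the numeric formatter, which only ever sees a float32 copy. Below is a rough standalone sketch of that pattern using plain NumPy; the helper name `tensor_to_string` is hypothetical, and float16 stands in for bfloat16 because stock NumPy has no bfloat16 dtype.

import numpy as np

def tensor_to_string(arr, dtype_name=None, prefix="Tensor"):
    """Hypothetical helper mirroring the pattern in to_string():
    report the original dtype by name, but upcast to float32 so the
    numeric formatter never needs to understand the low-precision type."""
    dtype_name = dtype_name or str(arr.dtype)
    printable = arr.astype(np.float32)  # upcast only for display
    indent = ' ' * (len(prefix) + 1)
    body = np.array2string(printable, precision=4, separator=', ', prefix=prefix + '(')
    return f"{prefix}(shape={list(arr.shape)}, dtype={dtype_name},\n{indent}{body})"

# float16 stands in for bfloat16 here; the printed header still says bfloat16.
x = np.array([[1.5, 1.0], [0.0, 0.0]], dtype=np.float16)
print(tensor_to_string(x, dtype_name="bfloat16"))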