From f57b21e662d151e999e4707cc922d737f439920c Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Tue, 8 Feb 2022 10:54:09 +0800 Subject: [PATCH] [bf16] support printing bf16 tensor (#39375) --- .../paddle/fluid/tests/unittests/test_var_base.py | 14 ++++++++++++++ python/paddle/tensor/to_string.py | 8 +++++++- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index c74dd24b78b..541df6659c2 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -1094,6 +1094,20 @@ class TestVarBase(unittest.TestCase): self.assertEqual(a_str, expected) paddle.enable_static() + def test_tensor_str_bf16(self): + paddle.disable_static(paddle.CPUPlace()) + a = paddle.to_tensor([[1.5, 1.0], [0, 0]]) + a = paddle.cast(a, dtype=core.VarDesc.VarType.BF16) + paddle.set_printoptions(precision=4) + a_str = str(a) + + expected = '''Tensor(shape=[2, 2], dtype=bfloat16, place=Place(cpu), stop_gradient=True, + [[1.5000, 1. ], + [0. , 0. ]])''' + + self.assertEqual(a_str, expected) + paddle.enable_static() + def test_print_tensor_dtype(self): paddle.disable_static(paddle.CPUPlace()) a = paddle.rand([1]) diff --git a/python/paddle/tensor/to_string.py b/python/paddle/tensor/to_string.py index 6fd20457fe6..af0f33f97ab 100644 --- a/python/paddle/tensor/to_string.py +++ b/python/paddle/tensor/to_string.py @@ -223,12 +223,18 @@ def _format_tensor(var, summary, indent=0, max_width=0, signed=False): def to_string(var, prefix='Tensor'): indent = len(prefix) + 1 + dtype = convert_dtype(var.dtype) + if var.dtype == core.VarDesc.VarType.BF16: + dtype = 'bfloat16' + _template = "{prefix}(shape={shape}, dtype={dtype}, place={place}, stop_gradient={stop_gradient},\n{indent}{data})" tensor = var.value().get_tensor() if not tensor._is_initialized(): return "Tensor(Not initialized)" + if var.dtype == core.VarDesc.VarType.BF16: + var = var.astype('float32') np_var = var.numpy() if len(var.shape) == 0: @@ -250,7 +256,7 @@ def to_string(var, prefix='Tensor'): return _template.format( prefix=prefix, shape=var.shape, - dtype=convert_dtype(var.dtype), + dtype=dtype, place=var._place_str, stop_gradient=var.stop_gradient, indent=' ' * indent, -- GitLab