From 543ff333cdf1434fa8ba77ed84f88b6db7c75b5b Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Wed, 11 Nov 2020 20:25:10 +0800 Subject: [PATCH] Refine the format of printing tensor 3 (support scaler tensor) (#28544) --- python/paddle/fluid/tests/unittests/test_var_base.py | 11 +++++++++++ python/paddle/tensor/to_string.py | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index 41aef68db6..511813fc1c 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -466,6 +466,17 @@ class TestVarBase(unittest.TestCase): self.assertEqual(a_str, expected) paddle.enable_static() + def test_tensor_str_scaler(self): + paddle.disable_static(paddle.CPUPlace()) + a = paddle.to_tensor(np.array(False)) + a_str = str(a) + + expected = '''Tensor(shape=[], dtype=bool, place=CPUPlace, stop_gradient=True, + False)''' + + self.assertEqual(a_str, expected) + paddle.enable_static() + class TestVarBaseSetitem(unittest.TestCase): def setUp(self): diff --git a/python/paddle/tensor/to_string.py b/python/paddle/tensor/to_string.py index bd956b923a..778a391df6 100644 --- a/python/paddle/tensor/to_string.py +++ b/python/paddle/tensor/to_string.py @@ -153,7 +153,7 @@ def _format_tensor(var, sumary, indent=0, max_width=0, signed=False): if len(var.shape) == 0: # currently, shape = [], i.e., scaler tensor is not supported. # If it is supported, it should be formatted like this. - return _format_item(var.item(0), max_width, signed) + return _format_item(var, max_width, signed) elif len(var.shape) == 1: if sumary and var.shape[0] > 2 * edgeitems: items = [ -- GitLab