未验证 提交 543ff333 编写于 作者: L Leo Chen 提交者: GitHub

Refine the format of printing tensor 3 (support scaler tensor) (#28544)

上级 1bf48365
......@@ -466,6 +466,17 @@ class TestVarBase(unittest.TestCase):
self.assertEqual(a_str, expected)
paddle.enable_static()
def test_tensor_str_scaler(self):
paddle.disable_static(paddle.CPUPlace())
a = paddle.to_tensor(np.array(False))
a_str = str(a)
expected = '''Tensor(shape=[], dtype=bool, place=CPUPlace, stop_gradient=True,
False)'''
self.assertEqual(a_str, expected)
paddle.enable_static()
class TestVarBaseSetitem(unittest.TestCase):
def setUp(self):
......
......@@ -153,7 +153,7 @@ def _format_tensor(var, sumary, indent=0, max_width=0, signed=False):
if len(var.shape) == 0:
# currently, shape = [], i.e., scaler tensor is not supported.
# If it is supported, it should be formatted like this.
return _format_item(var.item(0), max_width, signed)
return _format_item(var, max_width, signed)
elif len(var.shape) == 1:
if sumary and var.shape[0] > 2 * edgeitems:
items = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册