Unverified commit f4f823c8 authored by Leo Chen, committed by GitHub

Refine the format of printing tensor 2 (#28216)

* refine format

* update doc

* handle uninitialized tensor

* add ut
Parent 95ac49c3
......@@ -444,6 +444,40 @@ class TestVarBase(unittest.TestCase):
         self.assertEqual(a_str, expected)
         paddle.enable_static()
 
+    def test_tensor_str2(self):
+        paddle.disable_static(paddle.CPUPlace())
+        a = paddle.to_tensor([[1.5111111, 1.0], [0, 0]])
+        a_str = str(a)
+
+        if six.PY2:
+            expected = '''Tensor(shape=[2L, 2L], dtype=float32, place=CPUPlace, stop_gradient=True,
+       [[1.5111, 1.    ],
+        [0.    , 0.    ]])'''
+        else:
+            expected = '''Tensor(shape=[2, 2], dtype=float32, place=CPUPlace, stop_gradient=True,
+       [[1.5111, 1.    ],
+        [0.    , 0.    ]])'''
+
+        self.assertEqual(a_str, expected)
+        paddle.enable_static()
+
+    def test_tensor_str3(self):
+        paddle.disable_static(paddle.CPUPlace())
+        a = paddle.to_tensor([[-1.5111111, 1.0], [0, -0.5]])
+        a_str = str(a)
+
+        if six.PY2:
+            expected = '''Tensor(shape=[2L, 2L], dtype=float32, place=CPUPlace, stop_gradient=True,
+       [[-1.5111,  1.    ],
+        [ 0.    , -0.5000]])'''
+        else:
+            expected = '''Tensor(shape=[2, 2], dtype=float32, place=CPUPlace, stop_gradient=True,
+       [[-1.5111,  1.    ],
+        [ 0.    , -0.5000]])'''
+
+        self.assertEqual(a_str, expected)
+        paddle.enable_static()
+
 
 class TestVarBaseSetitem(unittest.TestCase):
     def setUp(self):
......
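The two tests added above pin down the refined repr. A minimal reproduction script, for illustration only (it assumes a Paddle build that contains this change, and it calls set_printoptions explicitly instead of relying on the earlier test_tensor_str case having set precision to 4), would be:

import paddle

paddle.disable_static(paddle.CPUPlace())
paddle.set_printoptions(precision=4)  # 4 decimal places, as the expected strings assume

x = paddle.to_tensor([[-1.5111111, 1.0], [0, -0.5]])
print(x)
# Expected to resemble the test_tensor_str3 string above:
# Tensor(shape=[2, 2], dtype=float32, place=CPUPlace, stop_gradient=True,
#        [[-1.5111,  1.    ],
#         [ 0.    , -0.5000]])

paddle.enable_static()

Note that the non-negative entries are padded with a leading space only because the tensor also contains negative values; that is the behaviour introduced by the signed flag in the to_string.py changes below.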
......@@ -58,18 +58,14 @@ def set_printoptions(precision=None,
             print(a)
 
             '''
-            Tensor: dygraph_tmp_0
-              - place: CPUPlace
-              - shape: [10, 20]
-              - layout: NCHW
-              - dtype: float32
-              - data: [[0.2727, 0.5489, 0.8655, ..., 0.2916, 0.8525, 0.9000],
-                       [0.3806, 0.8996, 0.0928, ..., 0.9535, 0.8378, 0.6409],
-                       [0.1484, 0.4038, 0.8294, ..., 0.0148, 0.6520, 0.4250],
+            Tensor(shape=[10, 20], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
+                   [[0.0002, 0.8503, 0.0135, ..., 0.9508, 0.2621, 0.6661],
+                    [0.9710, 0.2605, 0.9950, ..., 0.4427, 0.9241, 0.9363],
+                    [0.0948, 0.3226, 0.9955, ..., 0.1198, 0.0889, 0.9231],
                     ...,
-                       [0.3426, 0.1909, 0.7240, ..., 0.4218, 0.2676, 0.5679],
-                       [0.5561, 0.2081, 0.0676, ..., 0.9778, 0.3302, 0.9559],
-                       [0.2665, 0.8483, 0.5389, ..., 0.4956, 0.6862, 0.9178]]
+                    [0.7206, 0.0941, 0.5292, ..., 0.4856, 0.1379, 0.0351],
+                    [0.1745, 0.5621, 0.3602, ..., 0.2998, 0.4011, 0.1764],
+                    [0.0728, 0.7786, 0.0314, ..., 0.2583, 0.1654, 0.0637]])
             '''
     """
     kwargs = {}
......@@ -101,7 +97,7 @@ def _to_sumary(var):
         return var
     elif len(var.shape) == 1:
         if var.shape[0] > 2 * edgeitems:
-            return paddle.concat([var[:edgeitems], var[-edgeitems:]])
+            return np.concatenate([var[:edgeitems], var[-edgeitems:]])
         else:
             return var
     else:
......@@ -109,12 +105,12 @@ def _to_sumary(var):
         if var.shape[0] > 2 * edgeitems:
             begin = [x for x in var[:edgeitems]]
             end = [x for x in var[-edgeitems:]]
-            return paddle.stack([_to_sumary(x) for x in (begin + end)])
+            return np.stack([_to_sumary(x) for x in (begin + end)])
         else:
-            return paddle.stack([_to_sumary(x) for x in var])
+            return np.stack([_to_sumary(x) for x in var])
 
 
-def _format_item(np_var, max_width=0):
+def _format_item(np_var, max_width=0, signed=False):
     if np_var.dtype == np.float32 or np_var.dtype == np.float64 or np_var.dtype == np.float16:
         if DEFAULT_PRINT_OPTIONS.sci_mode:
             item_str = '{{:.{}e}}'.format(
......@@ -128,54 +124,66 @@ def _format_item(np_var, max_width=0):
         item_str = '{}'.format(np_var)
 
     if max_width > len(item_str):
-        return '{indent}{data}'.format(
-            indent=(max_width - len(item_str)) * ' ', data=item_str)
-    else:
+        if signed:  # handle sign character for tensor with negative item
+            if np_var < 0:
+                return item_str.ljust(max_width)
+            else:
+                return ' ' + item_str.ljust(max_width - 1)
+        else:
+            return item_str.ljust(max_width)
+    else:  # used for _get_max_width
         return item_str
 
 
 def _get_max_width(var):
     max_width = 0
-    for item in list(var.numpy().flatten()):
+    signed = False
+    for item in list(var.flatten()):
+        if (not signed) and (item < 0):
+            signed = True
         item_str = _format_item(item)
         max_width = max(max_width, len(item_str))
-    return max_width
+    return max_width, signed
-def _format_tensor(var, sumary, indent=0):
+def _format_tensor(var, sumary, indent=0, max_width=0, signed=False):
     edgeitems = DEFAULT_PRINT_OPTIONS.edgeitems
-    max_width = _get_max_width(_to_sumary(var))
 
     if len(var.shape) == 0:
         # currently, shape = [], i.e., scalar tensor is not supported.
         # If it is supported, it should be formatted like this.
-        return _format_item(var.numpy().item(0), max_width)
+        return _format_item(var.item(0), max_width, signed)
     elif len(var.shape) == 1:
         if sumary and var.shape[0] > 2 * edgeitems:
             items = [
-                _format_item(item, max_width)
-                for item in list(var.numpy())[:DEFAULT_PRINT_OPTIONS.edgeitems]
+                _format_item(item, max_width, signed)
+                for item in list(var)[:DEFAULT_PRINT_OPTIONS.edgeitems]
             ] + ['...'] + [
-                _format_item(item, max_width)
-                for item in list(var.numpy())[-DEFAULT_PRINT_OPTIONS.edgeitems:]
+                _format_item(item, max_width, signed)
+                for item in list(var)[-DEFAULT_PRINT_OPTIONS.edgeitems:]
             ]
         else:
             items = [
-                _format_item(item, max_width) for item in list(var.numpy())
+                _format_item(item, max_width, signed) for item in list(var)
             ]
         s = ', '.join(items)
         return '[' + s + ']'
     else:
         # recursively handle all dimensions
         if sumary and var.shape[0] > 2 * edgeitems:
             vars = [
-                _format_tensor(x, sumary, indent + 1) for x in var[:edgeitems]
+                _format_tensor(x, sumary, indent + 1, max_width, signed)
+                for x in var[:edgeitems]
             ] + ['...'] + [
-                _format_tensor(x, sumary, indent + 1) for x in var[-edgeitems:]
+                _format_tensor(x, sumary, indent + 1, max_width, signed)
+                for x in var[-edgeitems:]
             ]
         else:
-            vars = [_format_tensor(x, sumary, indent + 1) for x in var]
+            vars = [
+                _format_tensor(x, sumary, indent + 1, max_width, signed)
+                for x in var
+            ]
 
         return '[' + (',' + '\n' * (len(var.shape) - 1) + ' ' *
                       (indent + 1)).join(vars) + ']'
......@@ -190,6 +198,8 @@ def to_string(var, prefix='Tensor'):
+    if not tensor._is_initialized():
+        return "Tensor(Not initialized)"
 
     np_var = var.numpy()
 
     if len(var.shape) == 0:
         size = 0
     else:
......@@ -201,7 +211,10 @@ def to_string(var, prefix='Tensor'):
     if size > DEFAULT_PRINT_OPTIONS.threshold:
         sumary = True
 
-    data = _format_tensor(var, sumary, indent=indent)
+    max_width, signed = _get_max_width(_to_sumary(np_var))
+    data = _format_tensor(
+        np_var, sumary, indent=indent, max_width=max_width, signed=signed)
 
     return _template.format(
         prefix=prefix,
......
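To summarise the alignment rule that _get_max_width and _format_item now implement: the column width is the widest formatted item, a signed flag records whether any element is negative, and non-negative items then receive one leading space so their digits line up under the negative ones. A standalone sketch of that rule (plain Python with hypothetical names such as format_row; a per-row simplification for illustration, not the Paddle code itself):

import numpy as np

def format_row(values, precision=4):
    # Format each value roughly the way _format_item formats floats: integral
    # values get a trailing '.', everything else uses fixed-point with `precision`.
    strs = ['{:.0f}.'.format(v) if v == np.ceil(v) else
            '{:.{}f}'.format(v, precision) for v in values]
    max_width = max(len(s) for s in strs)
    signed = any(v < 0 for v in values)  # the extra flag _get_max_width now returns

    out = []
    for v, s in zip(values, strs):
        if signed and v >= 0:
            # reserve one column for the missing sign so digits stay aligned
            out.append(' ' + s.ljust(max_width - 1))
        else:
            out.append(s.ljust(max_width))
    return '[' + ', '.join(out) + ']'

print(format_row([-1.5111111, 1.0]))  # [-1.5111,  1.    ]
print(format_row([0.0, -0.5]))        # [ 0.    , -0.5000]

The two printed rows reproduce the layout asserted by test_tensor_str3 above; without the signed flag, " 1.    " would lose its leading space and the decimal points of the two rows would no longer line up.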