Unverified commit 049696bf, authored by Leo Chen, committed by GitHub

Refine the format of printing tensor (#27673)

* add summary feature

* refine printing tensor

* add sci_mode

* add sample code

* fix indent error

* fix _format_item

* polish code

* support item indent

* add ut

* set place for ut

* fix py2 issue

* fix ut
Parent c90d3556
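The net effect of this change: `print(tensor)` now renders a NumPy-style summary controlled by `paddle.set_printoptions`. A minimal sketch of the new behavior (values taken from the unit test added below; exact numbers depend on the seed and platform):

```python
import paddle

paddle.disable_static(paddle.CPUPlace())
paddle.manual_seed(10)

# precision=4, threshold=100, edgeitems=3 (positional, as in the new API)
paddle.set_printoptions(4, 100, 3)

x = paddle.rand([10, 20])
print(x)
# Tensor(shape=[10, 20], dtype=float32, place=CPUPlace, stop_gradient=True,
#        [[0.2727, 0.5489, 0.8655, ..., 0.2916, 0.8525, 0.9000],
#         ...,
#         [0.2665, 0.8483, 0.5389, ..., 0.4956, 0.6862, 0.9178]])
```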
......@@ -13,11 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/tensor_util.h"
#include <algorithm>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/profiler.h"
......@@ -943,6 +946,12 @@ void TensorFromDLPack(const ::DLTensor& dl_tensor, framework::Tensor* dst) {
#endif
}
template <typename T>
std::string format_tensor(const framework::Tensor& tensor) {
// TODO(zhiqiu): use the print option to format tensor.
return "NOT IMPLEMENTED";
}
template <typename T>
std::ostream& print_tensor(std::ostream& os, const framework::Tensor& tensor) {
auto inspect = tensor.data<T>();
......
......@@ -25,6 +25,26 @@ limitations under the License. */
namespace paddle {
namespace framework {
class PrintOptions {
 public:
  static PrintOptions& Instance() {
    static PrintOptions instance;
    return instance;
  }
  ~PrintOptions() {}
  PrintOptions(const PrintOptions& o) = delete;
  const PrintOptions& operator=(const PrintOptions& o) = delete;

  int precision = 8;
  int threshold = 1000;
  int edgeitems = 3;
  int linewidth = 75;
  bool sci_mode = false;

 private:
  PrintOptions() {}
};
// NOTE(zcd): Because TensorCopy is an async operation, when the src_place
// and dst_place are two different GPU, to ensure that the operation can
// be carried out correctly, there is a src_ctx wait operation in TensorCopy.
......
......@@ -17,10 +17,12 @@
//
#include <paddle/fluid/framework/op_registry.h>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/imperative/basic_engine.h"
#include "paddle/fluid/imperative/tracer.h"
......@@ -286,9 +288,9 @@ TEST(test_tracer, test_unique_name_generator) {
ASSERT_STREQ("fc_1", fc_2.c_str());
// use `dygraph_tmp` as the key if none is specified.
auto tmp_var_2 = tracer.GenerateUniqueName();
ASSERT_STREQ("eager_tmp_2", tmp_var_2.c_str());
auto tmp_var_3 = tracer.GenerateUniqueName("eager_tmp");
ASSERT_STREQ("eager_tmp_3", tmp_var_3.c_str());
ASSERT_STREQ("dygraph_tmp_2", tmp_var_2.c_str());
auto tmp_var_3 = tracer.GenerateUniqueName("dygraph_tmp");
ASSERT_STREQ("dygraph_tmp_3", tmp_var_3.c_str());
}
TEST(test_tracer, test_current_tracer) {
......
......@@ -20,6 +20,7 @@
#include <string>
#include <unordered_map>
#include <vector>
#include "ThreadPool.h"
#include "paddle/fluid/imperative/basic_engine.h"
#include "paddle/fluid/imperative/jit/program_desc_tracer.h"
......@@ -32,7 +33,7 @@ namespace imperative {
class UniqueNameGenerator {
public:
explicit UniqueNameGenerator(std::string prefix = "") : prefix_(prefix) {}
std::string Generate(std::string key = "eager_tmp") {
std::string Generate(std::string key = "dygraph_tmp") {
return prefix_ + key + "_" + std::to_string(id_++);
}
......@@ -83,7 +84,7 @@ class Tracer {
// name like `tmp_0` in some cases when transforming dygraph into static layers.
// So we change the default prefix key to `dygraph_tmp` to distinguish it from
// the static graph.
std::string GenerateUniqueName(std::string key = "eager_tmp") {
std::string GenerateUniqueName(std::string key = "dygraph_tmp") {
return generator_->Generate(key);
}
......
......@@ -833,6 +833,12 @@ void BindImperative(py::module *m_ptr) {
      .def_property_readonly(
          "place", [](imperative::VarBase &self) { return self.Place(); },
          py::return_value_policy::copy)
      .def_property_readonly("_place_str",
                             [](imperative::VarBase &self) {
                               std::stringstream ostr;
                               ostr << self.Place();
                               return ostr.str();
                             })
      .def_property_readonly("type", &imperative::VarBase::Type)
      .def_property_readonly("dtype", &imperative::VarBase::DataType);
......@@ -890,7 +896,7 @@ void BindImperative(py::module *m_ptr) {
&imperative::Tracer::GetProgramDescTracer,
py::return_value_policy::reference)
.def("_generate_unique_name", &imperative::Tracer::GenerateUniqueName,
py::arg("key") = "eager_tmp")
py::arg("key") = "dygraph_tmp")
.def(
"_set_amp_op_list",
[](imperative::Tracer &self,
......
......@@ -24,6 +24,7 @@ limitations under the License. */
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
......@@ -45,6 +46,7 @@ limitations under the License. */
#include "paddle/fluid/framework/save_load_util.h"
#include "paddle/fluid/framework/scope_pool.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/framework/version.h"
......@@ -440,6 +442,31 @@ PYBIND11_MODULE(core_noavx, m) {
&pb_vmap);
});
m.def("set_printoptions", [](const py::kwargs &kwargs) {
auto &print_opt = framework::PrintOptions::Instance();
if (kwargs.contains("precision")) {
print_opt.precision = kwargs["precision"].cast<int>();
}
if (kwargs.contains("threshold")) {
print_opt.threshold = kwargs["threshold"].cast<int>();
}
if (kwargs.contains("edgeitems")) {
print_opt.edgeitems = kwargs["edgeitems"].cast<int>();
}
if (kwargs.contains("linewidth")) {
print_opt.linewidth = kwargs["linewidth"].cast<int>();
}
if (kwargs.contains("sci_mode")) {
print_opt.sci_mode = kwargs["sci_mode"].cast<bool>();
}
VLOG(4) << "Set printoptions: precision=" << print_opt.precision
<< ", threshold=" << print_opt.threshold
<< ", edgeitems=" << print_opt.edgeitems
<< ", linewidth=" << print_opt.linewidth
<< ", sci_mode=" << print_opt.sci_mode;
});
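These kwargs land directly on the `PrintOptions` singleton above; unrecognized keys are silently ignored by the lambda. A minimal sketch of calling the raw binding (normally you would go through `paddle.set_printoptions` instead):

```python
from paddle.fluid import core

# Each recognized kwarg updates one field of PrintOptions::Instance().
core.set_printoptions(precision=4, sci_mode=True)
```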
m.def(
"_append_python_callable_object_and_return_id",
[](py::object py_obj) -> size_t {
......@@ -629,6 +656,8 @@ PYBIND11_MODULE(core_noavx, m) {
.def("_get_double_element", TensorGetElement<double>)
.def("_place", [](Tensor &self) { return self.place(); })
.def("_dtype", [](Tensor &self) { return self.type(); })
.def("_layout",
[](Tensor &self) { return DataLayoutToString(self.layout()); })
.def("_share_data_with", &Tensor::ShareDataWith)
.def("__getitem__", PySliceTensor, py::return_value_policy::reference)
.def("__str__", [](const Tensor &self) {
......
......@@ -218,6 +218,9 @@ from .tensor.search import where #DEFINE_ALIAS
from .tensor.search import index_select #DEFINE_ALIAS
from .tensor.search import nonzero #DEFINE_ALIAS
from .tensor.search import sort #DEFINE_ALIAS
from .tensor.to_string import set_printoptions
from .framework.random import manual_seed #DEFINE_ALIAS
from .framework.random import get_cuda_rng_state #DEFINE_ALIAS
from .framework.random import set_cuda_rng_state #DEFINE_ALIAS
......
......@@ -236,22 +236,15 @@ def monkey_patch_varbase():
.. code-block:: python
import paddle
paddle.disable_static()
x = paddle.rand([1, 5])
x = paddle.rand([2, 5])
print(x)
# Variable: eager_tmp_0
# - place: CUDAPlace(0)
# - shape: [1, 5]
# - layout: NCHW
# - dtype: float
# - data: [0.645307 0.597973 0.732793 0.646921 0.540328]
paddle.enable_static()
# Tensor(shape=[2, 5], dtype=float32, place=CPUPlace, stop_gradient=True,
#        [[0.30574632, 0.55739117, 0.30902600, 0.39413780, 0.44830436],
#         [0.79010487, 0.53972793, 0.09495186, 0.44267157, 0.72112119]])
"""
tensor = self.value().get_tensor()
if tensor._is_initialized():
return 'Tensor: %s\n%s' % (self.name, str(tensor))
else:
return 'Tensor: %s, not initialized' % (self.name)
from paddle.tensor.to_string import to_string
return to_string(self)
@property
def block(self):
......
......@@ -5310,8 +5310,8 @@ class ParamBase(core.VarBase):
# - data: [...]
paddle.enable_static()
"""
return "Parameter containing:\n {}\n - stop_gradient: {}".format(
super(ParamBase, self).__str__(), self.stop_gradient)
return "Parameter containing:\n{tensor}".format(
tensor=super(ParamBase, self).__str__())
__repr__ = __str__
......
......@@ -50,10 +50,10 @@ class TestImperativeUniqueName(unittest.TestCase):
with fluid.dygraph.guard():
tracer = fluid.framework._dygraph_tracer()
tmp_var_0 = tracer._generate_unique_name()
self.assertEqual(tmp_var_0, "eager_tmp_0")
self.assertEqual(tmp_var_0, "dygraph_tmp_0")
tmp_var_1 = tracer._generate_unique_name("eager_tmp")
self.assertEqual(tmp_var_1, "eager_tmp_1")
tmp_var_1 = tracer._generate_unique_name("dygraph_tmp")
self.assertEqual(tmp_var_1, "dygraph_tmp_1")
if __name__ == '__main__':
......
......@@ -404,6 +404,36 @@ class TestVarBase(unittest.TestCase):
self.assertListEqual(list(var_base.shape), list(static_var.shape))
    def test_tensor_str(self):
        paddle.disable_static(paddle.CPUPlace())
        paddle.manual_seed(10)
        a = paddle.rand([10, 20])
        paddle.set_printoptions(4, 100, 3)
        a_str = str(a)

        if six.PY2:
            expected = '''Tensor(shape=[10L, 20L], dtype=float32, place=CPUPlace, stop_gradient=True,
       [[0.2727, 0.5489, 0.8655, ..., 0.2916, 0.8525, 0.9000],
        [0.3806, 0.8996, 0.0928, ..., 0.9535, 0.8378, 0.6409],
        [0.1484, 0.4038, 0.8294, ..., 0.0148, 0.6520, 0.4250],
        ...,
        [0.3426, 0.1909, 0.7240, ..., 0.4218, 0.2676, 0.5679],
        [0.5561, 0.2081, 0.0676, ..., 0.9778, 0.3302, 0.9559],
        [0.2665, 0.8483, 0.5389, ..., 0.4956, 0.6862, 0.9178]])'''

        else:
            expected = '''Tensor(shape=[10, 20], dtype=float32, place=CPUPlace, stop_gradient=True,
       [[0.2727, 0.5489, 0.8655, ..., 0.2916, 0.8525, 0.9000],
        [0.3806, 0.8996, 0.0928, ..., 0.9535, 0.8378, 0.6409],
        [0.1484, 0.4038, 0.8294, ..., 0.0148, 0.6520, 0.4250],
        ...,
        [0.3426, 0.1909, 0.7240, ..., 0.4218, 0.2676, 0.5679],
        [0.5561, 0.2081, 0.0676, ..., 0.9778, 0.3302, 0.9559],
        [0.2665, 0.8483, 0.5389, ..., 0.4956, 0.6862, 0.9178]])'''

        self.assertEqual(a_str, expected)
        paddle.enable_static()


class TestVarBaseSetitem(unittest.TestCase):
    def setUp(self):
......
......@@ -193,3 +193,4 @@ from .stat import numel #DEFINE_ALIAS
# from .tensor import Tensor #DEFINE_ALIAS
# from .tensor import LoDTensor #DEFINE_ALIAS
# from .tensor import LoDTensorArray #DEFINE_ALIAS
from .to_string import set_printoptions
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
from paddle.fluid.layers import core
from paddle.fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
__all__ = ['set_printoptions']
class PrintOptions(object):
    precision = 8
    threshold = 1000
    edgeitems = 3
    linewidth = 80
    sci_mode = False


DEFAULT_PRINT_OPTIONS = PrintOptions()
def set_printoptions(precision=None,
                     threshold=None,
                     edgeitems=None,
                     sci_mode=None):
    """Set the printing options for Tensor.

    NOTE: The function is similar to numpy.set_printoptions().

    Args:
        precision (int, optional): Number of digits printed for floating point numbers, default 8.
        threshold (int, optional): Total number of elements printed before the output is summarized, default 1000.
        edgeitems (int, optional): Number of elements in the summary at the beginning and end of each dimension, default 3.
        sci_mode (bool, optional): Whether to format floating point numbers in scientific notation, default False.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle

            paddle.manual_seed(10)
            a = paddle.rand([10, 20])
            paddle.set_printoptions(4, 100, 3)
            print(a)

            '''
            Tensor(shape=[10, 20], dtype=float32, place=CPUPlace, stop_gradient=True,
                   [[0.2727, 0.5489, 0.8655, ..., 0.2916, 0.8525, 0.9000],
                    [0.3806, 0.8996, 0.0928, ..., 0.9535, 0.8378, 0.6409],
                    [0.1484, 0.4038, 0.8294, ..., 0.0148, 0.6520, 0.4250],
                    ...,
                    [0.3426, 0.1909, 0.7240, ..., 0.4218, 0.2676, 0.5679],
                    [0.5561, 0.2081, 0.0676, ..., 0.9778, 0.3302, 0.9559],
                    [0.2665, 0.8483, 0.5389, ..., 0.4956, 0.6862, 0.9178]])
            '''
    """
    kwargs = {}

    if precision is not None:
        check_type(precision, 'precision', (int), 'set_printoptions')
        DEFAULT_PRINT_OPTIONS.precision = precision
        kwargs['precision'] = precision
    if threshold is not None:
        check_type(threshold, 'threshold', (int), 'set_printoptions')
        DEFAULT_PRINT_OPTIONS.threshold = threshold
        kwargs['threshold'] = threshold
    if edgeitems is not None:
        check_type(edgeitems, 'edgeitems', (int), 'set_printoptions')
        DEFAULT_PRINT_OPTIONS.edgeitems = edgeitems
        kwargs['edgeitems'] = edgeitems
    if sci_mode is not None:
        check_type(sci_mode, 'sci_mode', (bool), 'set_printoptions')
        DEFAULT_PRINT_OPTIONS.sci_mode = sci_mode
        kwargs['sci_mode'] = sci_mode
    # TODO(zhiqiu): support linewidth
    core.set_printoptions(**kwargs)
def _to_summary(var):
    edgeitems = DEFAULT_PRINT_OPTIONS.edgeitems

    if len(var.shape) == 0:
        return var
    elif len(var.shape) == 1:
        if var.shape[0] > 2 * edgeitems:
            return paddle.concat([var[:edgeitems], var[-edgeitems:]])
        else:
            return var
    else:
        # recursively handle all dimensions
        if var.shape[0] > 2 * edgeitems:
            begin = [x for x in var[:edgeitems]]
            end = [x for x in var[-edgeitems:]]
            return paddle.stack([_to_summary(x) for x in (begin + end)])
        else:
            return paddle.stack([_to_summary(x) for x in var])
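For intuition, here is a self-contained NumPy equivalent of the recursion in `_to_summary` (a hypothetical helper, not part of this commit): every axis is cut down to its first and last `edgeitems` slices.

```python
import numpy as np

def _to_summary_np(arr, edgeitems=3):
    # Mirror of _to_summary: keep only the edge slices along each axis.
    if arr.ndim == 0:
        return arr
    if arr.ndim == 1:
        if arr.shape[0] > 2 * edgeitems:
            return np.concatenate([arr[:edgeitems], arr[-edgeitems:]])
        return arr
    parts = (list(arr[:edgeitems]) + list(arr[-edgeitems:])
             if arr.shape[0] > 2 * edgeitems else list(arr))
    return np.stack([_to_summary_np(p, edgeitems) for p in parts])

print(_to_summary_np(np.arange(100)).shape)      # (6,)
print(_to_summary_np(np.zeros((10, 20))).shape)  # (6, 6)
```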
def _format_item(np_var, max_width=0):
    if np_var.dtype == np.float32 or np_var.dtype == np.float64 or np_var.dtype == np.float16:
        if DEFAULT_PRINT_OPTIONS.sci_mode:
            item_str = '{{:.{}e}}'.format(
                DEFAULT_PRINT_OPTIONS.precision).format(np_var)
        elif np.ceil(np_var) == np_var:
            # an integral float is printed without a fractional part, e.g. `3.`
            item_str = '{:.0f}.'.format(np_var)
        else:
            item_str = '{{:.{}f}}'.format(
                DEFAULT_PRINT_OPTIONS.precision).format(np_var)
    else:
        item_str = '{}'.format(np_var)

    if max_width > len(item_str):
        # left-pad with spaces so all items in a column line up
        return '{indent}{data}'.format(
            indent=(max_width - len(item_str)) * ' ', data=item_str)
    else:
        return item_str
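The three float branches can be checked in isolation; a quick illustration of the format strings used above (precision 4 assumed):

```python
import numpy as np

x = np.float32(0.30574632)
print('{:.4f}'.format(x))                 # 0.3057      (default fixed notation)
print('{:.4e}'.format(x))                 # 3.0575e-01  (sci_mode=True path)
print('{:.0f}.'.format(np.float32(3.0)))  # 3.          (integral float keeps a bare dot)
```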
def _get_max_width(var):
    max_width = 0
    for item in np.nditer(var.numpy()):
        item_str = _format_item(item)
        max_width = max(max_width, len(item_str))
    return max_width
def _format_tensor(var, summary, indent=0):
    edgeitems = DEFAULT_PRINT_OPTIONS.edgeitems
    max_width = _get_max_width(_to_summary(var))

    if len(var.shape) == 0:
        # a 0-D tensor holds a single item
        return _format_item(var.numpy().item(0), max_width)
    elif len(var.shape) == 1:
        if summary and var.shape[0] > 2 * edgeitems:
            items = [
                _format_item(item, max_width)
                for item in list(var.numpy())[:DEFAULT_PRINT_OPTIONS.edgeitems]
            ] + ['...'] + [
                _format_item(item, max_width)
                for item in list(var.numpy())[-DEFAULT_PRINT_OPTIONS.edgeitems:]
            ]
        else:
            items = [
                _format_item(item, max_width) for item in list(var.numpy())
            ]
        s = ', '.join(items)
        return '[' + s + ']'
    else:
        # recursively handle all dimensions
        if summary and var.shape[0] > 2 * edgeitems:
            vars = [
                _format_tensor(x, summary, indent + 1) for x in var[:edgeitems]
            ] + ['...'] + [
                _format_tensor(x, summary, indent + 1) for x in var[-edgeitems:]
            ]
        else:
            vars = [_format_tensor(x, summary, indent + 1) for x in var]

        return '[' + (',' + '\n' * (len(var.shape) - 1) + ' ' *
                      (indent + 1)).join(vars) + ']'
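The only subtle part is the row separator: a comma, `rank - 1` newlines (so 3-D and higher tensors get blank lines between 2-D blocks), and `indent + 1` spaces for alignment. A quick check of just that expression (values assumed):

```python
rank, indent = 2, 7          # a 2-D tensor printed with the 'Tensor' prefix
sep = ',' + '\n' * (rank - 1) + ' ' * (indent + 1)
print(repr(sep))             # ',\n        ' -> comma, newline, 8 alignment spaces
```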
def to_string(var, prefix='Tensor'):
    indent = len(prefix) + 1

    _template = "{prefix}(shape={shape}, dtype={dtype}, place={place}, stop_gradient={stop_gradient},\n{indent}{data})"

    tensor = var.value().get_tensor()
    if not tensor._is_initialized():
        return "Tensor(Not initialized)"

    if len(var.shape) == 0:
        size = 0
    else:
        size = 1
        for dim in var.shape:
            size *= dim

    summary = False
    if size > DEFAULT_PRINT_OPTIONS.threshold:
        summary = True

    data = _format_tensor(var, summary, indent=indent)
    return _template.format(
        prefix=prefix,
        shape=var.shape,
        dtype=convert_dtype(var.dtype),
        place=var._place_str,
        stop_gradient=var.stop_gradient,
        indent=' ' * indent,
        data=data)
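Putting it together, a sketch of the `sci_mode` path (output is illustrative; assumes `paddle.to_tensor` is available in this version):

```python
import paddle

paddle.disable_static(paddle.CPUPlace())
paddle.set_printoptions(precision=2, sci_mode=True)

x = paddle.to_tensor([0.1234, 123.4])
print(x)
# Tensor(shape=[2], dtype=float32, place=CPUPlace, stop_gradient=True,
#        [1.23e-01, 1.23e+02])
```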