From 1a13626f5f5c334433b3051fec6eeca15c4942ab Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Tue, 26 Jan 2021 15:16:57 +0800 Subject: [PATCH] polish printing dtype (#30682) * polish printing dtype * fix special case --- python/paddle/fluid/data_feeder.py | 38 ++++++++----------- .../fluid/dygraph/varbase_patch_methods.py | 16 ++++++++ .../fluid/tests/unittests/test_var_base.py | 10 +++++ 3 files changed, 42 insertions(+), 22 deletions(-) diff --git a/python/paddle/fluid/data_feeder.py b/python/paddle/fluid/data_feeder.py index 8a68ad9d54..b2db00296b 100644 --- a/python/paddle/fluid/data_feeder.py +++ b/python/paddle/fluid/data_feeder.py @@ -26,31 +26,25 @@ from .framework import Variable, default_main_program, _current_expected_place, from .framework import _cpu_num, _cuda_ids __all__ = ['DataFeeder'] +_PADDLE_DTYPE_2_NUMPY_DTYPE = { + core.VarDesc.VarType.BOOL: 'bool', + core.VarDesc.VarType.FP16: 'float16', + core.VarDesc.VarType.FP32: 'float32', + core.VarDesc.VarType.FP64: 'float64', + core.VarDesc.VarType.INT8: 'int8', + core.VarDesc.VarType.INT16: 'int16', + core.VarDesc.VarType.INT32: 'int32', + core.VarDesc.VarType.INT64: 'int64', + core.VarDesc.VarType.UINT8: 'uint8', + core.VarDesc.VarType.COMPLEX64: 'complex64', + core.VarDesc.VarType.COMPLEX128: 'complex128', +} + def convert_dtype(dtype): if isinstance(dtype, core.VarDesc.VarType): - if dtype == core.VarDesc.VarType.BOOL: - return 'bool' - elif dtype == core.VarDesc.VarType.FP16: - return 'float16' - elif dtype == core.VarDesc.VarType.FP32: - return 'float32' - elif dtype == core.VarDesc.VarType.FP64: - return 'float64' - elif dtype == core.VarDesc.VarType.INT8: - return 'int8' - elif dtype == core.VarDesc.VarType.INT16: - return 'int16' - elif dtype == core.VarDesc.VarType.INT32: - return 'int32' - elif dtype == core.VarDesc.VarType.INT64: - return 'int64' - elif dtype == core.VarDesc.VarType.UINT8: - return 'uint8' - elif dtype == core.VarDesc.VarType.COMPLEX64: - return 'complex64' - elif dtype == core.VarDesc.VarType.COMPLEX128: - return 'complex128' + if dtype in _PADDLE_DTYPE_2_NUMPY_DTYPE: + return _PADDLE_DTYPE_2_NUMPY_DTYPE[dtype] elif isinstance(dtype, type): if dtype in [ np.bool, np.float16, np.float32, np.float64, np.int8, np.int16, diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index 7b0a3453b1..d3cf4d7bf3 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -23,6 +23,7 @@ from ..framework import Variable, Parameter, ParamBase from .base import switch_to_static_graph from .math_op_patch import monkey_patch_math_varbase from .parallel import scale_loss +from paddle.fluid.data_feeder import convert_dtype, _PADDLE_DTYPE_2_NUMPY_DTYPE def monkey_patch_varbase(): @@ -319,5 +320,20 @@ def monkey_patch_varbase(): ("__name__", "Tensor")): setattr(core.VarBase, method_name, method) + # NOTE(zhiqiu): pybind11 will set a default __str__ method of enum class. + # So, we need to overwrite it to a more readable one. + # See details in https://github.com/pybind/pybind11/issues/2537. + origin = getattr(core.VarDesc.VarType, "__repr__") + + def dtype_str(dtype): + if dtype in _PADDLE_DTYPE_2_NUMPY_DTYPE: + prefix = 'paddle.' + return prefix + _PADDLE_DTYPE_2_NUMPY_DTYPE[dtype] + else: + # for example, paddle.fluid.core.VarDesc.VarType.LOD_TENSOR + return origin(dtype) + + setattr(core.VarDesc.VarType, "__repr__", dtype_str) + # patch math methods for varbase monkey_patch_math_varbase() diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index 2f4a9c8e37..6c5458c1a2 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -617,6 +617,16 @@ class TestVarBase(unittest.TestCase): self.assertEqual(a_str, expected) paddle.enable_static() + def test_print_tensor_dtype(self): + paddle.disable_static(paddle.CPUPlace()) + a = paddle.rand([1]) + a_str = str(a.dtype) + + expected = 'paddle.float32' + + self.assertEqual(a_str, expected) + paddle.enable_static() + class TestVarBaseSetitem(unittest.TestCase): def setUp(self): -- GitLab