未验证 提交 1a13626f 编写于 作者: L Leo Chen 提交者: GitHub

polish printing dtype (#30682)

* polish printing dtype

* fix special case
上级 5bf25d1e
......@@ -26,31 +26,25 @@ from .framework import Variable, default_main_program, _current_expected_place,
from .framework import _cpu_num, _cuda_ids
__all__ = ['DataFeeder']
_PADDLE_DTYPE_2_NUMPY_DTYPE = {
core.VarDesc.VarType.BOOL: 'bool',
core.VarDesc.VarType.FP16: 'float16',
core.VarDesc.VarType.FP32: 'float32',
core.VarDesc.VarType.FP64: 'float64',
core.VarDesc.VarType.INT8: 'int8',
core.VarDesc.VarType.INT16: 'int16',
core.VarDesc.VarType.INT32: 'int32',
core.VarDesc.VarType.INT64: 'int64',
core.VarDesc.VarType.UINT8: 'uint8',
core.VarDesc.VarType.COMPLEX64: 'complex64',
core.VarDesc.VarType.COMPLEX128: 'complex128',
}
def convert_dtype(dtype):
if isinstance(dtype, core.VarDesc.VarType):
if dtype == core.VarDesc.VarType.BOOL:
return 'bool'
elif dtype == core.VarDesc.VarType.FP16:
return 'float16'
elif dtype == core.VarDesc.VarType.FP32:
return 'float32'
elif dtype == core.VarDesc.VarType.FP64:
return 'float64'
elif dtype == core.VarDesc.VarType.INT8:
return 'int8'
elif dtype == core.VarDesc.VarType.INT16:
return 'int16'
elif dtype == core.VarDesc.VarType.INT32:
return 'int32'
elif dtype == core.VarDesc.VarType.INT64:
return 'int64'
elif dtype == core.VarDesc.VarType.UINT8:
return 'uint8'
elif dtype == core.VarDesc.VarType.COMPLEX64:
return 'complex64'
elif dtype == core.VarDesc.VarType.COMPLEX128:
return 'complex128'
if dtype in _PADDLE_DTYPE_2_NUMPY_DTYPE:
return _PADDLE_DTYPE_2_NUMPY_DTYPE[dtype]
elif isinstance(dtype, type):
if dtype in [
np.bool, np.float16, np.float32, np.float64, np.int8, np.int16,
......
......@@ -23,6 +23,7 @@ from ..framework import Variable, Parameter, ParamBase
from .base import switch_to_static_graph
from .math_op_patch import monkey_patch_math_varbase
from .parallel import scale_loss
from paddle.fluid.data_feeder import convert_dtype, _PADDLE_DTYPE_2_NUMPY_DTYPE
def monkey_patch_varbase():
......@@ -319,5 +320,20 @@ def monkey_patch_varbase():
("__name__", "Tensor")):
setattr(core.VarBase, method_name, method)
# NOTE(zhiqiu): pybind11 will set a default __str__ method of enum class.
# So, we need to overwrite it to a more readable one.
# See details in https://github.com/pybind/pybind11/issues/2537.
origin = getattr(core.VarDesc.VarType, "__repr__")
def dtype_str(dtype):
if dtype in _PADDLE_DTYPE_2_NUMPY_DTYPE:
prefix = 'paddle.'
return prefix + _PADDLE_DTYPE_2_NUMPY_DTYPE[dtype]
else:
# for example, paddle.fluid.core.VarDesc.VarType.LOD_TENSOR
return origin(dtype)
setattr(core.VarDesc.VarType, "__repr__", dtype_str)
# patch math methods for varbase
monkey_patch_math_varbase()
......@@ -617,6 +617,16 @@ class TestVarBase(unittest.TestCase):
self.assertEqual(a_str, expected)
paddle.enable_static()
def test_print_tensor_dtype(self):
paddle.disable_static(paddle.CPUPlace())
a = paddle.rand([1])
a_str = str(a.dtype)
expected = 'paddle.float32'
self.assertEqual(a_str, expected)
paddle.enable_static()
class TestVarBaseSetitem(unittest.TestCase):
def setUp(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册