未验证 提交 b918d063 编写于 作者: Z zhangbo9674 提交者: GitHub

refine tensor.dtype print formate for bfloat16 (#44055)

* refine tensor.dtype for bloat16

* refine test

* revert

* refine bfloat16 print
上级 52607cf8
...@@ -1039,8 +1039,11 @@ def monkey_patch_varbase(): ...@@ -1039,8 +1039,11 @@ def monkey_patch_varbase():
def dtype_str(dtype): def dtype_str(dtype):
if dtype in _PADDLE_DTYPE_2_NUMPY_DTYPE: if dtype in _PADDLE_DTYPE_2_NUMPY_DTYPE:
numpy_dtype = _PADDLE_DTYPE_2_NUMPY_DTYPE[dtype]
if numpy_dtype == 'uint16':
numpy_dtype = 'bfloat16'
prefix = 'paddle.' prefix = 'paddle.'
return prefix + _PADDLE_DTYPE_2_NUMPY_DTYPE[dtype] return prefix + numpy_dtype
else: else:
# for example, paddle.fluid.core.VarDesc.VarType.LOD_TENSOR # for example, paddle.fluid.core.VarDesc.VarType.LOD_TENSOR
return origin(dtype) return origin(dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册