Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
1a13626f
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1a13626f
编写于
1月 26, 2021
作者:
L
Leo Chen
提交者:
GitHub
1月 26, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
polish printing dtype (#30682)
* polish printing dtype * fix special case
上级
5bf25d1e
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
42 addition
and
22 deletion
+42
-22
python/paddle/fluid/data_feeder.py
python/paddle/fluid/data_feeder.py
+16
-22
python/paddle/fluid/dygraph/varbase_patch_methods.py
python/paddle/fluid/dygraph/varbase_patch_methods.py
+16
-0
python/paddle/fluid/tests/unittests/test_var_base.py
python/paddle/fluid/tests/unittests/test_var_base.py
+10
-0
未找到文件。
python/paddle/fluid/data_feeder.py
浏览文件 @
1a13626f
...
...
@@ -26,31 +26,25 @@ from .framework import Variable, default_main_program, _current_expected_place,
from
.framework
import
_cpu_num
,
_cuda_ids
__all__
=
[
'DataFeeder'
]
_PADDLE_DTYPE_2_NUMPY_DTYPE
=
{
core
.
VarDesc
.
VarType
.
BOOL
:
'bool'
,
core
.
VarDesc
.
VarType
.
FP16
:
'float16'
,
core
.
VarDesc
.
VarType
.
FP32
:
'float32'
,
core
.
VarDesc
.
VarType
.
FP64
:
'float64'
,
core
.
VarDesc
.
VarType
.
INT8
:
'int8'
,
core
.
VarDesc
.
VarType
.
INT16
:
'int16'
,
core
.
VarDesc
.
VarType
.
INT32
:
'int32'
,
core
.
VarDesc
.
VarType
.
INT64
:
'int64'
,
core
.
VarDesc
.
VarType
.
UINT8
:
'uint8'
,
core
.
VarDesc
.
VarType
.
COMPLEX64
:
'complex64'
,
core
.
VarDesc
.
VarType
.
COMPLEX128
:
'complex128'
,
}
def
convert_dtype
(
dtype
):
if
isinstance
(
dtype
,
core
.
VarDesc
.
VarType
):
if
dtype
==
core
.
VarDesc
.
VarType
.
BOOL
:
return
'bool'
elif
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
return
'float16'
elif
dtype
==
core
.
VarDesc
.
VarType
.
FP32
:
return
'float32'
elif
dtype
==
core
.
VarDesc
.
VarType
.
FP64
:
return
'float64'
elif
dtype
==
core
.
VarDesc
.
VarType
.
INT8
:
return
'int8'
elif
dtype
==
core
.
VarDesc
.
VarType
.
INT16
:
return
'int16'
elif
dtype
==
core
.
VarDesc
.
VarType
.
INT32
:
return
'int32'
elif
dtype
==
core
.
VarDesc
.
VarType
.
INT64
:
return
'int64'
elif
dtype
==
core
.
VarDesc
.
VarType
.
UINT8
:
return
'uint8'
elif
dtype
==
core
.
VarDesc
.
VarType
.
COMPLEX64
:
return
'complex64'
elif
dtype
==
core
.
VarDesc
.
VarType
.
COMPLEX128
:
return
'complex128'
if
dtype
in
_PADDLE_DTYPE_2_NUMPY_DTYPE
:
return
_PADDLE_DTYPE_2_NUMPY_DTYPE
[
dtype
]
elif
isinstance
(
dtype
,
type
):
if
dtype
in
[
np
.
bool
,
np
.
float16
,
np
.
float32
,
np
.
float64
,
np
.
int8
,
np
.
int16
,
...
...
python/paddle/fluid/dygraph/varbase_patch_methods.py
浏览文件 @
1a13626f
...
...
@@ -23,6 +23,7 @@ from ..framework import Variable, Parameter, ParamBase
from
.base
import
switch_to_static_graph
from
.math_op_patch
import
monkey_patch_math_varbase
from
.parallel
import
scale_loss
from
paddle.fluid.data_feeder
import
convert_dtype
,
_PADDLE_DTYPE_2_NUMPY_DTYPE
def
monkey_patch_varbase
():
...
...
@@ -319,5 +320,20 @@ def monkey_patch_varbase():
(
"__name__"
,
"Tensor"
)):
setattr
(
core
.
VarBase
,
method_name
,
method
)
# NOTE(zhiqiu): pybind11 will set a default __str__ method of enum class.
# So, we need to overwrite it to a more readable one.
# See details in https://github.com/pybind/pybind11/issues/2537.
origin
=
getattr
(
core
.
VarDesc
.
VarType
,
"__repr__"
)
def
dtype_str
(
dtype
):
if
dtype
in
_PADDLE_DTYPE_2_NUMPY_DTYPE
:
prefix
=
'paddle.'
return
prefix
+
_PADDLE_DTYPE_2_NUMPY_DTYPE
[
dtype
]
else
:
# for example, paddle.fluid.core.VarDesc.VarType.LOD_TENSOR
return
origin
(
dtype
)
setattr
(
core
.
VarDesc
.
VarType
,
"__repr__"
,
dtype_str
)
# patch math methods for varbase
monkey_patch_math_varbase
()
python/paddle/fluid/tests/unittests/test_var_base.py
浏览文件 @
1a13626f
...
...
@@ -617,6 +617,16 @@ class TestVarBase(unittest.TestCase):
self
.
assertEqual
(
a_str
,
expected
)
paddle
.
enable_static
()
def
test_print_tensor_dtype
(
self
):
paddle
.
disable_static
(
paddle
.
CPUPlace
())
a
=
paddle
.
rand
([
1
])
a_str
=
str
(
a
.
dtype
)
expected
=
'paddle.float32'
self
.
assertEqual
(
a_str
,
expected
)
paddle
.
enable_static
()
class
TestVarBaseSetitem
(
unittest
.
TestCase
):
def
setUp
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录