未验证 提交 370f0345 编写于 作者: P pkpk 提交者: GitHub

fix the bug in data_feeder.py (#20791)

* test=develop

* test=develop

* test=develop

* test=develop
上级 ac813bba
...@@ -27,13 +27,7 @@ __all__ = ['DataFeeder'] ...@@ -27,13 +27,7 @@ __all__ = ['DataFeeder']
def convert_dtype(dtype): def convert_dtype(dtype):
if isinstance(dtype, str): if isinstance(dtype, core.VarDesc.VarType):
if dtype in [
'bool', 'float16', 'float32', 'float64', 'int8', 'int16',
'int32', 'int64', 'uint8'
]:
return dtype
else:
if dtype == core.VarDesc.VarType.BOOL: if dtype == core.VarDesc.VarType.BOOL:
return 'bool' return 'bool'
elif dtype == core.VarDesc.VarType.FP16: elif dtype == core.VarDesc.VarType.FP16:
...@@ -52,6 +46,19 @@ def convert_dtype(dtype): ...@@ -52,6 +46,19 @@ def convert_dtype(dtype):
return 'int64' return 'int64'
elif dtype == core.VarDesc.VarType.UINT8: elif dtype == core.VarDesc.VarType.UINT8:
return 'uint8' return 'uint8'
else:
if dtype in [
'bool', 'float16', 'float32', 'float64', 'int8', 'int16',
'int32', 'int64', 'uint8', u'bool', u'float16', u'float32',
u'float64', u'int8', u'int16', u'int32', u'int64', u'uint8'
]:
# this code is a little bit dangerous, since error could happen
# when casting no-asci code to str in python2.
# but since the set itself is limited, so currently, it is good.
# however, jointly supporting python2 and python3, (as well as python4 maybe)
# may still be a long-lasting problem.
return str(dtype)
raise ValueError( raise ValueError(
"dtype must be any of [bool, float16, float32, float64, int8, int16, " "dtype must be any of [bool, float16, float32, float64, int8, int16, "
"int32, int64, uint8]") "int32, int64, uint8]")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册