diff --git a/python/paddle/fluid/data_feeder.py b/python/paddle/fluid/data_feeder.py index fa1d7f8cddb899859a13a8068544b60ffb7598f4..a0dcb8cd9a7c10ab14a9cef4620ad71b78654cae 100644 --- a/python/paddle/fluid/data_feeder.py +++ b/python/paddle/fluid/data_feeder.py @@ -27,13 +27,7 @@ __all__ = ['DataFeeder'] def convert_dtype(dtype): - if isinstance(dtype, str): - if dtype in [ - 'bool', 'float16', 'float32', 'float64', 'int8', 'int16', - 'int32', 'int64', 'uint8' - ]: - return dtype - else: + if isinstance(dtype, core.VarDesc.VarType): if dtype == core.VarDesc.VarType.BOOL: return 'bool' elif dtype == core.VarDesc.VarType.FP16: @@ -52,6 +46,19 @@ def convert_dtype(dtype): return 'int64' elif dtype == core.VarDesc.VarType.UINT8: return 'uint8' + else: + if dtype in [ + 'bool', 'float16', 'float32', 'float64', 'int8', 'int16', + 'int32', 'int64', 'uint8', u'bool', u'float16', u'float32', + u'float64', u'int8', u'int16', u'int32', u'int64', u'uint8' + ]: + # this code is a little bit dangerous, since error could happen + # when casting no-asci code to str in python2. + # but since the set itself is limited, so currently, it is good. + # however, jointly supporting python2 and python3, (as well as python4 maybe) + # may still be a long-lasting problem. + return str(dtype) + raise ValueError( "dtype must be any of [bool, float16, float32, float64, int8, int16, " "int32, int64, uint8]")