From 370f0345b6d35a513c8e64d519a0edfc96b9276c Mon Sep 17 00:00:00 2001 From: pkpk Date: Thu, 24 Oct 2019 16:38:57 +0800 Subject: [PATCH] fix the bug in data_feeder.py (#20791) * test=develop * test=develop * test=develop * test=develop --- python/paddle/fluid/data_feeder.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/data_feeder.py b/python/paddle/fluid/data_feeder.py index fa1d7f8cdd..a0dcb8cd9a 100644 --- a/python/paddle/fluid/data_feeder.py +++ b/python/paddle/fluid/data_feeder.py @@ -27,13 +27,7 @@ __all__ = ['DataFeeder'] def convert_dtype(dtype): - if isinstance(dtype, str): - if dtype in [ - 'bool', 'float16', 'float32', 'float64', 'int8', 'int16', - 'int32', 'int64', 'uint8' - ]: - return dtype - else: + if isinstance(dtype, core.VarDesc.VarType): if dtype == core.VarDesc.VarType.BOOL: return 'bool' elif dtype == core.VarDesc.VarType.FP16: @@ -52,6 +46,19 @@ def convert_dtype(dtype): return 'int64' elif dtype == core.VarDesc.VarType.UINT8: return 'uint8' + else: + if dtype in [ + 'bool', 'float16', 'float32', 'float64', 'int8', 'int16', + 'int32', 'int64', 'uint8', u'bool', u'float16', u'float32', + u'float64', u'int8', u'int16', u'int32', u'int64', u'uint8' + ]: + # this code is a little bit dangerous, since error could happen + # when casting no-asci code to str in python2. + # but since the set itself is limited, so currently, it is good. + # however, jointly supporting python2 and python3, (as well as python4 maybe) + # may still be a long-lasting problem. + return str(dtype) + raise ValueError( "dtype must be any of [bool, float16, float32, float64, int8, int16, " "int32, int64, uint8]") -- GitLab