From 9c76ba6cd2dea1525f98d3aaa3430064ed8767df Mon Sep 17 00:00:00 2001 From: pkpk Date: Mon, 28 Oct 2019 11:12:45 +0800 Subject: [PATCH] cherry-pick data_feeder bug fix to 1.6 (#20840) * test=document_fix * test=release/1.6 --- python/paddle/fluid/data_feeder.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/data_feeder.py b/python/paddle/fluid/data_feeder.py index fa1d7f8cddb..a0dcb8cd9a7 100644 --- a/python/paddle/fluid/data_feeder.py +++ b/python/paddle/fluid/data_feeder.py @@ -27,13 +27,7 @@ __all__ = ['DataFeeder'] def convert_dtype(dtype): - if isinstance(dtype, str): - if dtype in [ - 'bool', 'float16', 'float32', 'float64', 'int8', 'int16', - 'int32', 'int64', 'uint8' - ]: - return dtype - else: + if isinstance(dtype, core.VarDesc.VarType): if dtype == core.VarDesc.VarType.BOOL: return 'bool' elif dtype == core.VarDesc.VarType.FP16: @@ -52,6 +46,19 @@ def convert_dtype(dtype): return 'int64' elif dtype == core.VarDesc.VarType.UINT8: return 'uint8' + else: + if dtype in [ + 'bool', 'float16', 'float32', 'float64', 'int8', 'int16', + 'int32', 'int64', 'uint8', u'bool', u'float16', u'float32', + u'float64', u'int8', u'int16', u'int32', u'int64', u'uint8' + ]: + # this code is a little bit dangerous, since error could happen + # when casting no-asci code to str in python2. + # but since the set itself is limited, so currently, it is good. + # however, jointly supporting python2 and python3, (as well as python4 maybe) + # may still be a long-lasting problem. + return str(dtype) + raise ValueError( "dtype must be any of [bool, float16, float32, float64, int8, int16, " "int32, int64, uint8]") -- GitLab