From 256dc6ee4883885439dca3325583b17c44fd285e Mon Sep 17 00:00:00 2001 From: ShenLiang <2282912238@qq.com> Date: Thu, 20 Aug 2020 18:51:30 +0800 Subject: [PATCH] set_default_dtype only support float (#26435) --- .../tests/unittests/test_default_dtype.py | 24 +++++++++++++-- python/paddle/framework/framework.py | 30 +++++++++++++++++-- 2 files changed, 49 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_default_dtype.py b/python/paddle/fluid/tests/unittests/test_default_dtype.py index eba4ec3420f..057933fc7a7 100644 --- a/python/paddle/fluid/tests/unittests/test_default_dtype.py +++ b/python/paddle/fluid/tests/unittests/test_default_dtype.py @@ -33,8 +33,28 @@ class TestDefaultType(unittest.TestCase): set_default_dtype("float64") self.assertEqual("float64", get_default_dtype()) - set_default_dtype(np.int32) - self.assertEqual("int32", get_default_dtype()) + set_default_dtype("float32") + self.assertEqual("float32", get_default_dtype()) + + set_default_dtype("float16") + self.assertEqual("float16", get_default_dtype()) + + set_default_dtype(np.float64) + self.assertEqual("float64", get_default_dtype()) + + set_default_dtype(np.float32) + self.assertEqual("float32", get_default_dtype()) + + set_default_dtype(np.float16) + self.assertEqual("float16", get_default_dtype()) + + +class TestRaiseError(unittest.TestCase): + def test_error(self): + self.assertRaises(TypeError, set_default_dtype, "int32") + self.assertRaises(TypeError, set_default_dtype, np.int32) + self.assertRaises(TypeError, set_default_dtype, "int64") + self.assertRaises(TypeError, set_default_dtype, np.int64) if __name__ == '__main__': diff --git a/python/paddle/framework/framework.py b/python/paddle/framework/framework.py index 4d5b2c8e6fc..41ec18ce32d 100644 --- a/python/paddle/framework/framework.py +++ b/python/paddle/framework/framework.py @@ -15,6 +15,7 @@ # TODO: define framework api from paddle.fluid.layer_helper_base import LayerHelperBase from paddle.fluid.data_feeder import convert_dtype +import numpy as np __all__ = ['set_default_dtype', 'get_default_dtype'] @@ -24,7 +25,8 @@ def set_default_dtype(d): Set default dtype. The default dtype is initially float32 Args: - d(string|np.dtype): the dtype to make the default + d(string|np.dtype): the dtype to make the default. It only + supports float16, float32 and float64. Returns: None. @@ -36,13 +38,35 @@ def set_default_dtype(d): paddle.set_default_dtype("float32") """ - d = convert_dtype(d) + if isinstance(d, type): + if d in [np.float16, np.float32, np.float64]: + d = d.__name__ + else: + raise TypeError( + "set_default_dtype only supports [float16, float32, float64] " + ", but received %s" % d.__name__) + else: + if d in [ + 'float16', 'float32', 'float64', u'float16', u'float32', + u'float64' + ]: + # this code is a little bit dangerous, since error could happen + # when casting no-ascii code to str in python2. + # but since the set itself is limited, so currently, it is good. + # however, jointly supporting python2 and python3, (as well as python4 maybe) + # may still be a long-lasting problem. + d = str(d) + else: + raise TypeError( + "set_default_dtype only supports [float16, float32, float64] " + ", but received %s" % str(d)) + LayerHelperBase.set_default_dtype(d) def get_default_dtype(): """ - Get the current default dtype. The default dtype is initially float32 + Get the current default dtype. The default dtype is initially float32. Args: None. -- GitLab