diff --git a/python/paddle/fluid/tests/unittests/test_default_dtype.py b/python/paddle/fluid/tests/unittests/test_default_dtype.py index eba4ec3420f2d5b56a880ce314b7168fb5e84154..057933fc7a735c2732cd651e83e99ddfa747b8a8 100644 --- a/python/paddle/fluid/tests/unittests/test_default_dtype.py +++ b/python/paddle/fluid/tests/unittests/test_default_dtype.py @@ -33,8 +33,28 @@ class TestDefaultType(unittest.TestCase): set_default_dtype("float64") self.assertEqual("float64", get_default_dtype()) - set_default_dtype(np.int32) - self.assertEqual("int32", get_default_dtype()) + set_default_dtype("float32") + self.assertEqual("float32", get_default_dtype()) + + set_default_dtype("float16") + self.assertEqual("float16", get_default_dtype()) + + set_default_dtype(np.float64) + self.assertEqual("float64", get_default_dtype()) + + set_default_dtype(np.float32) + self.assertEqual("float32", get_default_dtype()) + + set_default_dtype(np.float16) + self.assertEqual("float16", get_default_dtype()) + + +class TestRaiseError(unittest.TestCase): + def test_error(self): + self.assertRaises(TypeError, set_default_dtype, "int32") + self.assertRaises(TypeError, set_default_dtype, np.int32) + self.assertRaises(TypeError, set_default_dtype, "int64") + self.assertRaises(TypeError, set_default_dtype, np.int64) if __name__ == '__main__': diff --git a/python/paddle/framework/framework.py b/python/paddle/framework/framework.py index 4d5b2c8e6fcb13826d9e1a0d6738c351be7cf0b1..41ec18ce32d3036c3db86aaa98053f59ff61f717 100644 --- a/python/paddle/framework/framework.py +++ b/python/paddle/framework/framework.py @@ -15,6 +15,7 @@ # TODO: define framework api from paddle.fluid.layer_helper_base import LayerHelperBase from paddle.fluid.data_feeder import convert_dtype +import numpy as np __all__ = ['set_default_dtype', 'get_default_dtype'] @@ -24,7 +25,8 @@ def set_default_dtype(d): Set default dtype. The default dtype is initially float32 Args: - d(string|np.dtype): the dtype to make the default + d(string|np.dtype): the dtype to make the default. It only + supports float16, float32 and float64. Returns: None. @@ -36,13 +38,35 @@ def set_default_dtype(d): paddle.set_default_dtype("float32") """ - d = convert_dtype(d) + if isinstance(d, type): + if d in [np.float16, np.float32, np.float64]: + d = d.__name__ + else: + raise TypeError( + "set_default_dtype only supports [float16, float32, float64] " + ", but received %s" % d.__name__) + else: + if d in [ + 'float16', 'float32', 'float64', u'float16', u'float32', + u'float64' + ]: + # this code is a little bit dangerous, since error could happen + # when casting no-ascii code to str in python2. + # but since the set itself is limited, so currently, it is good. + # however, jointly supporting python2 and python3, (as well as python4 maybe) + # may still be a long-lasting problem. + d = str(d) + else: + raise TypeError( + "set_default_dtype only supports [float16, float32, float64] " + ", but received %s" % str(d)) + LayerHelperBase.set_default_dtype(d) def get_default_dtype(): """ - Get the current default dtype. The default dtype is initially float32 + Get the current default dtype. The default dtype is initially float32. Args: None.