From 410f1629425b060cbb4bc664980b4046e85105a8 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Wed, 15 Mar 2023 16:01:37 +0800 Subject: [PATCH] support set_default_dtype bf16 (#51650) * support set_default_dtype bf16 * support float --- python/paddle/fluid/layer_helper_base.py | 14 +++++++++----- .../fluid/tests/unittests/test_default_dtype.py | 11 +++++++++++ python/paddle/framework/framework.py | 6 +++--- 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/python/paddle/fluid/layer_helper_base.py b/python/paddle/fluid/layer_helper_base.py index 66f9f75ecc6..16fddd7520a 100644 --- a/python/paddle/fluid/layer_helper_base.py +++ b/python/paddle/fluid/layer_helper_base.py @@ -394,14 +394,18 @@ class LayerHelperBase: and dtype != core.VarDesc.VarType.BF16 ): raise TypeError( - "Can not create parameter with default initializer when dtype is not float type. Set default_initializer to fit the parameter dtype!" + "Can not create parameter with default initializer when dtype is not ['float16', 'float32', 'float64', 'bfloat16'] type. Set default_initializer to fit the parameter dtype!" ) else: - if not ( - dtype.startswith("float") or dtype in ["double", "uint16"] - ): + if dtype not in [ + 'float16', + 'float32', + 'float64', + 'bfloat16', + 'float', + ]: raise TypeError( - "Can not create parameter with default initializer when dtype is not float type. Set default_initializer to fit the parameter dtype!" + "Can not create parameter with default initializer when dtype is not ['float16', 'float32', 'float64', 'bfloat16', 'float'] type. Set default_initializer to fit the parameter dtype!" ) if is_bias: attr._set_default_bias_initializer() diff --git a/python/paddle/fluid/tests/unittests/test_default_dtype.py b/python/paddle/fluid/tests/unittests/test_default_dtype.py index ceaa9447cfc..37ac6dd476b 100644 --- a/python/paddle/fluid/tests/unittests/test_default_dtype.py +++ b/python/paddle/fluid/tests/unittests/test_default_dtype.py @@ -16,6 +16,7 @@ import unittest import numpy as np +import paddle from paddle.framework import get_default_dtype, set_default_dtype @@ -35,6 +36,9 @@ class TestDefaultType(unittest.TestCase): set_default_dtype("float16") self.assertEqual("float16", get_default_dtype()) + set_default_dtype("bfloat16") + self.assertEqual("bfloat16", get_default_dtype()) + set_default_dtype(np.float64) self.assertEqual("float64", get_default_dtype()) @@ -45,6 +49,13 @@ class TestDefaultType(unittest.TestCase): self.assertEqual("float16", get_default_dtype()) +class TestDefaultTypeInLayer(unittest.TestCase): + def test_bfloat16(self): + set_default_dtype("bfloat16") + linear = paddle.nn.Linear(10, 20) + self.assertEqual(linear.weight.dtype, paddle.bfloat16) + + class TestRaiseError(unittest.TestCase): def test_error(self): self.assertRaises(TypeError, set_default_dtype, "int32") diff --git a/python/paddle/framework/framework.py b/python/paddle/framework/framework.py index cbcee9ead70..2e9df1a8ee3 100644 --- a/python/paddle/framework/framework.py +++ b/python/paddle/framework/framework.py @@ -27,7 +27,7 @@ def set_default_dtype(d): Args: d(string|np.dtype): the dtype to make the default. It only - supports float16, float32 and float64. + supports float16, bfloat16, float32 and float64. Returns: None. @@ -50,14 +50,14 @@ def set_default_dtype(d): ) else: # This branch is for np.dtype and str - if d in ['float16', 'float32', 'float64']: + if d in ['float16', 'float32', 'float64', 'bfloat16']: # NOTE(SigureMo): Since the np.dtype object is not an instance of # type, so it will not be handled by the previous branch. We need # to convert it to str here. d = str(d) else: raise TypeError( - "set_default_dtype only supports [float16, float32, float64] " + "set_default_dtype only supports [float16, float32, float64, bfloat16] " ", but received %s" % str(d) ) -- GitLab