未验证 提交 410f1629 编写于 作者: L Leo Chen 提交者: GitHub

support set_default_dtype bf16 (#51650)

* support set_default_dtype bf16

* support float
上级 c9ca7c35
...@@ -394,14 +394,18 @@ class LayerHelperBase: ...@@ -394,14 +394,18 @@ class LayerHelperBase:
and dtype != core.VarDesc.VarType.BF16 and dtype != core.VarDesc.VarType.BF16
): ):
raise TypeError( raise TypeError(
"Can not create parameter with default initializer when dtype is not float type. Set default_initializer to fit the parameter dtype!" "Can not create parameter with default initializer when dtype is not ['float16', 'float32', 'float64', 'bfloat16'] type. Set default_initializer to fit the parameter dtype!"
) )
else: else:
if not ( if dtype not in [
dtype.startswith("float") or dtype in ["double", "uint16"] 'float16',
): 'float32',
'float64',
'bfloat16',
'float',
]:
raise TypeError( raise TypeError(
"Can not create parameter with default initializer when dtype is not float type. Set default_initializer to fit the parameter dtype!" "Can not create parameter with default initializer when dtype is not ['float16', 'float32', 'float64', 'bfloat16', 'float'] type. Set default_initializer to fit the parameter dtype!"
) )
if is_bias: if is_bias:
attr._set_default_bias_initializer() attr._set_default_bias_initializer()
......
...@@ -16,6 +16,7 @@ import unittest ...@@ -16,6 +16,7 @@ import unittest
import numpy as np import numpy as np
import paddle
from paddle.framework import get_default_dtype, set_default_dtype from paddle.framework import get_default_dtype, set_default_dtype
...@@ -35,6 +36,9 @@ class TestDefaultType(unittest.TestCase): ...@@ -35,6 +36,9 @@ class TestDefaultType(unittest.TestCase):
set_default_dtype("float16") set_default_dtype("float16")
self.assertEqual("float16", get_default_dtype()) self.assertEqual("float16", get_default_dtype())
set_default_dtype("bfloat16")
self.assertEqual("bfloat16", get_default_dtype())
set_default_dtype(np.float64) set_default_dtype(np.float64)
self.assertEqual("float64", get_default_dtype()) self.assertEqual("float64", get_default_dtype())
...@@ -45,6 +49,13 @@ class TestDefaultType(unittest.TestCase): ...@@ -45,6 +49,13 @@ class TestDefaultType(unittest.TestCase):
self.assertEqual("float16", get_default_dtype()) self.assertEqual("float16", get_default_dtype())
class TestDefaultTypeInLayer(unittest.TestCase):
def test_bfloat16(self):
set_default_dtype("bfloat16")
linear = paddle.nn.Linear(10, 20)
self.assertEqual(linear.weight.dtype, paddle.bfloat16)
class TestRaiseError(unittest.TestCase): class TestRaiseError(unittest.TestCase):
def test_error(self): def test_error(self):
self.assertRaises(TypeError, set_default_dtype, "int32") self.assertRaises(TypeError, set_default_dtype, "int32")
......
...@@ -27,7 +27,7 @@ def set_default_dtype(d): ...@@ -27,7 +27,7 @@ def set_default_dtype(d):
Args: Args:
d(string|np.dtype): the dtype to make the default. It only d(string|np.dtype): the dtype to make the default. It only
supports float16, float32 and float64. supports float16, bfloat16, float32 and float64.
Returns: Returns:
None. None.
...@@ -50,14 +50,14 @@ def set_default_dtype(d): ...@@ -50,14 +50,14 @@ def set_default_dtype(d):
) )
else: else:
# This branch is for np.dtype and str # This branch is for np.dtype and str
if d in ['float16', 'float32', 'float64']: if d in ['float16', 'float32', 'float64', 'bfloat16']:
# NOTE(SigureMo): Since the np.dtype object is not an instance of # NOTE(SigureMo): Since the np.dtype object is not an instance of
# type, so it will not be handled by the previous branch. We need # type, so it will not be handled by the previous branch. We need
# to convert it to str here. # to convert it to str here.
d = str(d) d = str(d)
else: else:
raise TypeError( raise TypeError(
"set_default_dtype only supports [float16, float32, float64] " "set_default_dtype only supports [float16, float32, float64, bfloat16] "
", but received %s" % str(d) ", but received %s" % str(d)
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册