未验证 提交 410f1629 编写于 作者: L Leo Chen 提交者: GitHub

support set_default_dtype bf16 (#51650)

* support set_default_dtype bf16

* support float
上级 c9ca7c35
......@@ -394,14 +394,18 @@ class LayerHelperBase:
and dtype != core.VarDesc.VarType.BF16
):
raise TypeError(
"Can not create parameter with default initializer when dtype is not float type. Set default_initializer to fit the parameter dtype!"
"Can not create parameter with default initializer when dtype is not ['float16', 'float32', 'float64', 'bfloat16'] type. Set default_initializer to fit the parameter dtype!"
)
else:
if not (
dtype.startswith("float") or dtype in ["double", "uint16"]
):
if dtype not in [
'float16',
'float32',
'float64',
'bfloat16',
'float',
]:
raise TypeError(
"Can not create parameter with default initializer when dtype is not float type. Set default_initializer to fit the parameter dtype!"
"Can not create parameter with default initializer when dtype is not ['float16', 'float32', 'float64', 'bfloat16', 'float'] type. Set default_initializer to fit the parameter dtype!"
)
if is_bias:
attr._set_default_bias_initializer()
......
......@@ -16,6 +16,7 @@ import unittest
import numpy as np
import paddle
from paddle.framework import get_default_dtype, set_default_dtype
......@@ -35,6 +36,9 @@ class TestDefaultType(unittest.TestCase):
set_default_dtype("float16")
self.assertEqual("float16", get_default_dtype())
set_default_dtype("bfloat16")
self.assertEqual("bfloat16", get_default_dtype())
set_default_dtype(np.float64)
self.assertEqual("float64", get_default_dtype())
......@@ -45,6 +49,13 @@ class TestDefaultType(unittest.TestCase):
self.assertEqual("float16", get_default_dtype())
class TestDefaultTypeInLayer(unittest.TestCase):
def test_bfloat16(self):
set_default_dtype("bfloat16")
linear = paddle.nn.Linear(10, 20)
self.assertEqual(linear.weight.dtype, paddle.bfloat16)
class TestRaiseError(unittest.TestCase):
def test_error(self):
self.assertRaises(TypeError, set_default_dtype, "int32")
......
......@@ -27,7 +27,7 @@ def set_default_dtype(d):
Args:
d(string|np.dtype): the dtype to make the default. It only
supports float16, float32 and float64.
supports float16, bfloat16, float32 and float64.
Returns:
None.
......@@ -50,14 +50,14 @@ def set_default_dtype(d):
)
else:
# This branch is for np.dtype and str
if d in ['float16', 'float32', 'float64']:
if d in ['float16', 'float32', 'float64', 'bfloat16']:
# NOTE(SigureMo): Since the np.dtype object is not an instance of
# type, so it will not be handled by the previous branch. We need
# to convert it to str here.
d = str(d)
else:
raise TypeError(
"set_default_dtype only supports [float16, float32, float64] "
"set_default_dtype only supports [float16, float32, float64, bfloat16] "
", but received %s" % str(d)
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册