未验证 提交 a64d50b7 编写于 作者: L liuruyan 提交者: GitHub

Add layer func: float(), half(), bfloat16(). (#51635)

上级 998235e6
...@@ -373,7 +373,16 @@ void BindPlace(pybind11::module &m) { // NOLINT ...@@ -373,7 +373,16 @@ void BindPlace(pybind11::module &m) { // NOLINT
#endif #endif
.def("__repr__", string::to_string<const platform::CUDAPlace &>) .def("__repr__", string::to_string<const platform::CUDAPlace &>)
.def("__str__", string::to_string<const platform::CUDAPlace &>); .def("__str__", string::to_string<const platform::CUDAPlace &>);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
m.def("is_float16_supported", [](const platform::CUDAPlace &place) -> bool {
// Only GPUs with Compute Capability >= 53 support float16
return platform::GetGPUComputeCapability(place.device) >= 53;
});
m.def("is_bfloat16_supported", [](const platform::CUDAPlace &place) -> bool {
// Only GPUs with Compute Capability >= 80 support bfloat16
return platform::GetGPUComputeCapability(place.device) >= 80;
});
#endif
py::class_<platform::XPUPlace> xpuplace(m, "XPUPlace", R"DOC( py::class_<platform::XPUPlace> xpuplace(m, "XPUPlace", R"DOC(
**Note**: **Note**:
Examples: Examples:
...@@ -492,7 +501,18 @@ void BindPlace(pybind11::module &m) { // NOLINT ...@@ -492,7 +501,18 @@ void BindPlace(pybind11::module &m) { // NOLINT
&IsSamePlace<platform::CPUPlace, platform::CUDAPinnedPlace>) &IsSamePlace<platform::CPUPlace, platform::CUDAPinnedPlace>)
.def("__repr__", string::to_string<const platform::CPUPlace &>) .def("__repr__", string::to_string<const platform::CPUPlace &>)
.def("__str__", string::to_string<const platform::CPUPlace &>); .def("__str__", string::to_string<const platform::CPUPlace &>);
m.def("is_float16_supported",
[](const platform::CPUPlace &place) -> bool { return false; });
m.def("is_bfloat16_supported", [](const platform::CPUPlace &place) -> bool {
#ifndef PADDLE_WITH_MKLDNN
return false;
#else
if (phi::backends::cpu::MayIUse(phi::backends::cpu::cpu_isa_t::avx512_core))
return true;
else
return false;
#endif
});
py::class_<paddle::platform::CUDAPinnedPlace> cudapinnedplace( py::class_<paddle::platform::CUDAPinnedPlace> cudapinnedplace(
m, "CUDAPinnedPlace", R"DOC( m, "CUDAPinnedPlace", R"DOC(
CUDAPinnedPlace is a descriptor of a device. CUDAPinnedPlace is a descriptor of a device.
......
...@@ -1960,17 +1960,6 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -1960,17 +1960,6 @@ All parameter, weight, gradient are variables in Paddle.
py::arg("sleep_inter") = 0, py::arg("sleep_inter") = 0,
py::arg("redirect_stderr") = false); py::arg("redirect_stderr") = false);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
m.def("is_float16_supported", [](const platform::CUDAPlace &place) -> bool {
// Only GPUs with Compute Capability >= 53 support float16
return platform::GetGPUComputeCapability(place.device) >= 53;
});
m.def("is_bfloat16_supported", [](const platform::CUDAPlace &place) -> bool {
// Only GPUs with Compute Capability >= 80 support bfloat16
return platform::GetGPUComputeCapability(place.device) >= 80;
});
#endif
m.def("set_feed_variable", m.def("set_feed_variable",
static_cast<void (*)( // NOLINT static_cast<void (*)( // NOLINT
Scope *, Scope *,
......
...@@ -28,4 +28,68 @@ from .grad_scaler import OptimizerState # noqa: F401 ...@@ -28,4 +28,68 @@ from .grad_scaler import OptimizerState # noqa: F401
from . import debugging # noqa: F401 from . import debugging # noqa: F401
__all__ = ['auto_cast', 'GradScaler', 'decorate'] from paddle.fluid import core
from paddle.fluid.framework import (
_current_expected_place,
_get_paddle_place,
)
__all__ = [
'auto_cast',
'GradScaler',
'decorate',
'is_float16_supported',
'is_bfloat16_supported',
]
def is_float16_supported(device=None):
"""
Determine whether the place supports float16 in the auto-mixed-precision training.
Args:
device (str|None, optional): Specify the running device.
It can be ``cpu``, ``gpu``, ``xpu``, ``gpu:x`` and ``xpu:x``,
where ``x`` is the index of the GPUs or XPUs. if device is None, the device is the current device. Default: None.
Examples:
.. code-block:: python
import paddle
paddle.amp.is_float16_supported() # True or False
"""
device = (
_current_expected_place()
if device is None
else _get_paddle_place(device)
)
return core.is_float16_supported(device)
def is_bfloat16_supported(device=None):
"""
Determine whether the place supports bfloat16 in the auto-mixed-precision training.
Args:
device (str|None, optional): Specify the running device.
It can be ``cpu``, ``gpu``, ``xpu``, ``gpu:x`` and ``xpu:x``,
where ``x`` is the index of the GPUs or XPUs. if device is None, the device is the current device. Default: None.
Examples:
.. code-block:: python
import paddle
paddle.amp.is_bfloat16_supported() # True or False
"""
device = (
_current_expected_place()
if device is None
else _get_paddle_place(device)
)
return core.is_bfloat16_supported(device)
...@@ -22,7 +22,7 @@ import weakref ...@@ -22,7 +22,7 @@ import weakref
import numpy as np import numpy as np
import paddle import paddle
from paddle import profiler from paddle import nn, profiler
from paddle.fluid import core, framework, unique_name from paddle.fluid import core, framework, unique_name
from paddle.fluid.core import VarDesc from paddle.fluid.core import VarDesc
from paddle.fluid.dygraph import no_grad from paddle.fluid.dygraph import no_grad
...@@ -125,6 +125,13 @@ def _addindent(string, indent): ...@@ -125,6 +125,13 @@ def _addindent(string, indent):
return s1[0] + '\n' + '\n'.join(s2) return s1[0] + '\n' + '\n'.join(s2)
def _layer_trans_dtype(layer, dtype, excluded_layers):
if type(layer) in excluded_layers:
return
layer._to_impl(dtype=dtype, floating_only=True, include_sublayers=False)
class LayerObjectHelper(LayerHelperBase): class LayerObjectHelper(LayerHelperBase):
def __init__(self, name): def __init__(self, name):
super().__init__(name, layer_type=name) super().__init__(name, layer_type=name)
...@@ -2146,3 +2153,170 @@ class Layer: ...@@ -2146,3 +2153,170 @@ class Layer:
# [aliases] Compatible with old method names # [aliases] Compatible with old method names
set_dict = set_state_dict set_dict = set_state_dict
load_dict = set_state_dict load_dict = set_state_dict
def float(self, excluded_layers=None):
'''
Casts all floating point parameters and buffers to ``float`` data type.
Parameters:
excluded_layers(nn.Layer|list|None, optional): Specify the layers that need to be kept original data type. if excluded_layers is None, casts all floating point parameters and buffers. Default: None.
Returns:
Layer: self
Examples:
.. code-block:: python
import paddle
class Model(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.linear = paddle.nn.Linear(1, 1)
self.dropout = paddle.nn.Dropout(p=0.5)
def forward(self, input):
out = self.linear(input)
out = self.dropout(out)
return out
model = Model()
model.float()
'''
excluded_layers = [] if excluded_layers is None else excluded_layers
if isinstance(excluded_layers, type):
excluded_layers = [excluded_layers]
elif isinstance(excluded_layers, list):
pass
else:
raise TypeError(
"excluded_layers should be type nn.Layer or list, but got %s.",
type(excluded_layers).__name__,
)
def layer_trans(layer):
_layer_trans_dtype(layer, paddle.float32, excluded_layers)
return self.apply(layer_trans)
def float16(self, excluded_layers=None):
'''
Casts all floating point parameters and buffers to ``float16`` data type.
.. note::
``nn.BatchNorm`` does not support ``bfloat16`` weights, so it would not be converted by default.
Parameters:
excluded_layers(nn.Layer|list|None, optional): Specify the layers that need to be kept original data type. if excluded_layers is None, casts all floating point parameters and buffers except ``nn.BatchNorm``. Default: None.
Returns:
Layer: self
Examples:
.. code-block:: python
import paddle
class Model(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.linear = paddle.nn.Linear(1, 1)
self.dropout = paddle.nn.Dropout(p=0.5)
def forward(self, input):
out = self.linear(input)
out = self.dropout(out)
return out
model = Model()
model.float16()
'''
if paddle.amp.is_float16_supported() is False:
warnings.warn(
"Paddle compiled by the user does not support float16, so keep original data type."
)
return self
excluded_layers = (
[nn.BatchNorm] if excluded_layers is None else excluded_layers
)
if isinstance(excluded_layers, type):
excluded_layers = [excluded_layers]
elif isinstance(excluded_layers, list):
pass
else:
raise TypeError(
"excluded_layers should be type nn.Layer or list, but got %s.",
type(excluded_layers).__name__,
)
def layer_trans(layer):
_layer_trans_dtype(layer, paddle.float16, excluded_layers)
return self.apply(layer_trans)
def bfloat16(self, excluded_layers=None):
'''
Casts all floating point parameters and buffers to ``bfloat16`` data type.
.. note::
``nn.BatchNorm`` does not support ``bfloat16`` weights, so it would not be converted by default.
Parameters:
excluded_layers(nn.Layer|list|None, optional): Specify the layers that need to be kept original data type. if excluded_layers is None, casts all floating point parameters and buffers except ``nn.BatchNorm``. Default: None.
Returns:
Layer: self
Examples:
.. code-block:: python
import paddle
class Model(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.linear = paddle.nn.Linear(1, 1)
self.dropout = paddle.nn.Dropout(p=0.5)
def forward(self, input):
out = self.linear(input)
out = self.dropout(out)
return out
model = Model()
model.bfloat16()
'''
if paddle.amp.is_bfloat16_supported() is False:
warnings.warn(
"Paddle compiled by the user does not support bfloat16, so keep original data type."
)
return self
excluded_layers = (
[nn.BatchNorm] if excluded_layers is None else excluded_layers
)
if isinstance(excluded_layers, type):
excluded_layers = [excluded_layers]
elif isinstance(excluded_layers, list):
pass
else:
raise TypeError(
"excluded_layers should be type nn.Layer or list, but got %s.",
type(excluded_layers).__name__,
)
def layer_trans(layer):
_layer_trans_dtype(layer, paddle.bfloat16, excluded_layers)
return self.apply(layer_trans)
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import paddle.nn.functional as F
from paddle import nn
from paddle.fluid import core
class MyModel(paddle.nn.Layer):
def __init__(self, input_size, hidden_size):
super().__init__()
self.linear1 = paddle.nn.Linear(input_size, hidden_size)
self.linear2 = paddle.nn.Linear(hidden_size, hidden_size)
self.linear3 = paddle.nn.Linear(hidden_size, 1)
self.batchnorm = paddle.nn.Sequential(paddle.nn.BatchNorm(hidden_size))
register_buffer_in_temp = paddle.ones([4, 6])
self.register_buffer('register_buffer_in', register_buffer_in_temp)
def forward(self, inputs):
x = self.linear1(inputs)
x = F.relu(x)
x = self.batchnorm(x)
x = self.linear3(x)
return x
@unittest.skipIf(
not core.is_compiled_with_cuda(), "Require compiled with CUDA."
)
class TestDtypeConvert(unittest.TestCase):
def setUp(self):
self.batch_size, self.input_size, self.hidden_size = 128, 128, 256
def verify_trans_dtype(
self, test_type=None, excluded_layers=None, corrected_dtype=None
):
model = MyModel(self.input_size, self.hidden_size)
if test_type == 'float16':
model.float16(excluded_layers=excluded_layers)
elif test_type == 'bfloat16':
model.bfloat16(excluded_layers=excluded_layers)
else:
model.float(excluded_layers=excluded_layers)
for name, para in model.named_parameters():
if 'linear' in name:
self.assertEqual(para.dtype, corrected_dtype)
elif 'batchnorm' in name:
if excluded_layers is None:
self.assertEqual(para.dtype, paddle.float32)
else:
self.assertEqual(para.dtype, paddle.float16)
def test_excluded_layers(self):
self.verify_trans_dtype(
test_type='float16',
excluded_layers=[nn.Linear],
corrected_dtype=paddle.float32,
)
self.verify_trans_dtype(
test_type='float16',
excluded_layers=nn.Linear,
corrected_dtype=paddle.float32,
)
def test_float16(self):
self.verify_trans_dtype(
test_type='float16',
corrected_dtype=paddle.float16,
)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or paddle.device.cuda.get_device_capability()[0] >= 8.0,
"run test when maximum gpu's compute capability is 8.0.",
)
def test_unsupported_bfloat16(self):
self.verify_trans_dtype(
test_type='bfloat16',
corrected_dtype=paddle.float32,
)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or paddle.device.cuda.get_device_capability()[0] < 8.0,
"run test when gpu's compute capability is at least 8.0.",
)
def test_supported_bfloat16(self):
self.verify_trans_dtype(
test_type='bfloat16',
corrected_dtype=paddle.bfloat16,
)
def test_float32(self):
paddle.set_default_dtype('float16')
self.verify_trans_dtype(
test_type='float32',
corrected_dtype=paddle.float32,
)
paddle.set_default_dtype('float32')
def test_excluded_layers_type_error(self):
self.assertRaises(
TypeError, self.verify_trans_dtype, excluded_layers=111
)
@unittest.skipIf(
not core.is_compiled_with_cuda(), "Require compiled with CUDA."
)
class TestSupportedTypeInfo(unittest.TestCase):
def test_cpu(self):
res = paddle.amp.is_float16_supported('cpu')
self.assertEqual(res, False)
res = paddle.amp.is_bfloat16_supported('cpu')
self.assertEqual(res, True)
def test_gpu_fp16_supported(self):
res = paddle.amp.is_float16_supported()
self.assertEqual(res, True)
res = paddle.amp.is_float16_supported('gpu')
self.assertEqual(res, True)
res = paddle.amp.is_float16_supported('gpu:0')
self.assertEqual(res, True)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or paddle.device.cuda.get_device_capability()[0] >= 8.0,
"run test when maximum gpu's compute capability is 8.0.",
)
def test_gpu_bf16_unsupported(self):
res = paddle.amp.is_bfloat16_supported()
self.assertEqual(res, False)
res = paddle.amp.is_bfloat16_supported('gpu')
self.assertEqual(res, False)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or paddle.device.cuda.get_device_capability()[0] < 8.0,
"run test when gpu's compute capability is at least 8.0.",
)
def test_gpu_bf16_supported(self):
res = paddle.amp.is_bfloat16_supported()
self.assertEqual(res, True)
res = paddle.amp.is_bfloat16_supported('gpu')
self.assertEqual(res, True)
def test_device_value_error(self):
self.assertRaises(
ValueError, paddle.amp.is_float16_supported, device='xxx'
)
self.assertRaises(
ValueError, paddle.amp.is_float16_supported, device=111
)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册