未验证 提交 18e9dcdc 编写于 作者: Z Zhang Ting 提交者: GitHub

[AMP] support OD level and skip dynamic loss scaling for bf16 (#53289)

* support OD level and skip dynamic loss scaling for bf16
上级 3278dec7
......@@ -129,7 +129,11 @@ inline phi::DataType GetAmpDestDtype(
->count(op_name)) {
dst_type = phi::DataType::FLOAT32;
} else {
dst_type = GetPromoteType(op_name, amp_tensors_vector, amp_setting_dtype);
if (amp_level == paddle::imperative::AmpLevel::OD) {
dst_type = phi::DataType::FLOAT32;
} else {
dst_type = GetPromoteType(op_name, amp_tensors_vector, amp_setting_dtype);
}
}
if (dst_type == amp_setting_dtype &&
......
......@@ -31,6 +31,7 @@ enum class AmpLevel {
O1, // amp, mixed fp32-fp16
O2, // almost fp16
O3, // fp16
OD, // only conv and matmul use low precison.
};
std::tuple<std::unordered_set<std::string>,
......
......@@ -2152,6 +2152,7 @@ void BindImperative(py::module *m_ptr) {
py::enum_<paddle::imperative::AmpLevel>(m, "AmpLevel", py::arithmetic())
.value("O0", paddle::imperative::AmpLevel::O0)
.value("OD", paddle::imperative::AmpLevel::OD)
.value("O1", paddle::imperative::AmpLevel::O1)
.value("O2", paddle::imperative::AmpLevel::O2)
.value("O3", paddle::imperative::AmpLevel::O3)
......
......@@ -91,10 +91,19 @@ BF16_WHITE_LIST = {'conv2d', 'matmul_v2'}
BF16_BLACK_LIST = set()
# At OD level, ops in WHITE_LIST will use FP16/BF16 and the others will use FP32.
def white_list():
white_list = {
"float16": {"O1": FP16_WHITE_LIST, "O2": FP16_WHITE_LIST},
"bfloat16": {"O1": BF16_WHITE_LIST, "O2": BF16_WHITE_LIST},
"float16": {
"OD": FP16_WHITE_LIST,
"O1": FP16_WHITE_LIST,
"O2": FP16_WHITE_LIST,
},
"bfloat16": {
"OD": BF16_WHITE_LIST,
"O1": BF16_WHITE_LIST,
"O2": BF16_WHITE_LIST,
},
}
return white_list
......@@ -102,9 +111,10 @@ def white_list():
def black_list():
black_list = {
"float16": {
"OD": set(),
"O1": FP16_BLACK_LIST | FP16_EXTRA_BLACK_LIST,
"O2": FP16_EXTRA_BLACK_LIST,
},
"bfloat16": {"O1": BF16_BLACK_LIST, "O2": set()},
"bfloat16": {"OD": set(), "O1": BF16_BLACK_LIST, "O2": set()},
}
return black_list
......@@ -48,6 +48,7 @@ class AMPGlobalState:
self.model_parameters = []
self.use_master_grad = False
self.already_register_final_backward_hook = False
self.amp_dtype = 'float32'
def __setattr__(self, name, val):
self.__dict__[name] = val
......@@ -320,10 +321,8 @@ def amp_guard(
# check amp_level: O0-O2
level = level.upper()
if not (level in ['O0', 'O1', 'O2']):
raise ValueError(
"level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode."
)
if not (level in ['O0', 'OD', 'O1', 'O2']):
raise ValueError("level should be O0, OD, O1 or O2.")
# check amp_dtype: float16 or bfloat16
dtype = dtype.lower()
......@@ -384,8 +383,11 @@ def amp_guard(
)
amp_dtype = dtype
amp_global_state().amp_dtype = amp_dtype
if level == 'O1':
if level == 'OD':
amp_level = AMP_LEVEL.OD
elif level == 'O1':
amp_level = AMP_LEVEL.O1
elif level == 'O2':
amp_level = AMP_LEVEL.O2
......@@ -642,22 +644,24 @@ def auto_cast(
):
"""
Create a context which enables auto-mixed-precision(AMP) of operators executed in dynamic graph mode.
If enabled, the input data type (float32 or float16) of each operator is decided
If enabled, the input data type (float32, float16 or bfloat16) of each operator is decided
by autocast algorithm for better performance.
Commonly, it is used together with `GradScaler` to achieve Auto-Mixed-Precision in
imperative mode. It is used together with `decorator` to achieve Pure fp16 in imperative mode.
Commonly, it is used together with `GradScaler` and `decorator` to achieve Auto-Mixed-Precision in
imperative mode.
Args:
enable(bool, optional): Enable auto-mixed-precision or not. Default is True.
custom_white_list(set|list|tuple, optional): The custom white_list. It's the set of ops that support
fp16 calculation and are considered numerically-safe and performance-critical. These ops
will be converted to fp16.
custom_black_list(set|list|tuple, optional): The custom black_list. The set of ops that support fp16
calculation and are considered numerically-dangerous and whose effects may also be
observed in downstream ops. These ops will not be converted to fp16.
level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision, the input data type of each operator will be casted by white_list and black_list;
O2 represent Pure fp16, all operators parameters and input data will be casted to fp16, except operators in black_list, don't support fp16 kernel and batchnorm. Default is O1(amp)
custom_white_list(set|list|tuple, optional): A default white list is already set. Usually there is no need to set custom white list.
The set of ops should be considered numerically-safe and performance-critical. These ops will be converted to float16/bfloat16.
custom_black_list(set|list|tuple, optional): A default black list is already set. You can set a custom black list according to the model.
The set of ops are considered numerically-dangerous and whose effects may also be observed in downstream ops. These ops will not be
converted to float16/bfloat16.
level(str, optional): Auto mixed precision level. Accepted values are "O1", "O2" and "OD": At the O1 level, operators in the white list
will use float16/bfloat16 inputs for calculations, and operators in the black list will use float32 inputs for calculations. At the O2
level, model's parameters will be casted to float16/bfloat16 by using `decorator`, and operators that have all float16/bfloat16 inputs
will be converted to float16/bfloat16, and that have any float32 input will be converted to float32. For the OD level, operators in
default white list will compute in float16/bfloat16, and the others will compute in float32. Default is O1.
dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
Examples:
......
......@@ -24,6 +24,8 @@ from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph import to_variable
from paddle.fluid.framework import _dygraph_tracer, dygraph_only
from .auto_cast import amp_global_state
class OptimizerState(Enum):
INIT = 0
......@@ -179,6 +181,18 @@ class AmpScaler:
"""
check_type(var, "var", core.eager.Tensor, 'AmpScaler.scale()')
if (
self._enable
and amp_global_state().amp_dtype != 'float16'
and self._use_dynamic_loss_scaling
):
self._enable = False
self._use_dynamic_loss_scaling = False
warnings.warn(
'It is not recommended to use dynamic loss scaling for %s, so GradScaler is disable by default.'
% (amp_global_state().amp_dtype)
)
if not self._enable:
return var
......
......@@ -194,8 +194,11 @@ class TestAutoCast(unittest.TestCase):
class TestAmpScaler(unittest.TestCase):
def scale(self):
if not paddle.amp.is_float16_supported():
return
with fluid.dygraph.guard():
data = paddle.rand([10, 1024])
with paddle.amp.auto_cast(dtype='float16'):
data = paddle.rand([10, 1024])
scaler = paddle.amp.AmpScaler(init_loss_scaling=1024)
scaled_data = scaler.scale(data)
self.assertEqual(
......@@ -333,9 +336,9 @@ class TestAmpScaler(unittest.TestCase):
)
scaler = paddle.amp.AmpScaler(init_loss_scaling=1024)
data = fluid.dygraph.to_variable(inp_np)
out = model(data)
loss = paddle.mean(out)
with paddle.amp.auto_cast(dtype='float16'):
out = model(data)
loss = paddle.mean(out)
scaled_loss = scaler.scale(loss)
scaled_loss.backward()
optimize_ops, params_grads = scaler.minimize(optimizer, scaled_loss)
......@@ -348,6 +351,8 @@ class TestAmpScaler(unittest.TestCase):
)
def test_nan_inf(self):
if not paddle.amp.is_float16_supported():
return
self.nan_inf()
def step_update_exception(self):
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from amp_base_models import AmpTestBase
import paddle
class TestAutoCast(AmpTestBase):
def test_amp_OD_level(self):
conv = paddle.nn.Conv2D(
in_channels=1, out_channels=6, kernel_size=3, bias_attr=False
)
linear = paddle.nn.Linear(in_features=4, out_features=4)
with paddle.amp.auto_cast(level='OD'):
out1 = conv(paddle.rand(shape=[1, 1, 6, 6], dtype='float32'))
out2 = out1 + paddle.rand(shape=out1.shape, dtype='float16')
out3 = linear(out2)
self.assertEqual(out1.dtype, paddle.float16)
self.assertEqual(out2.dtype, paddle.float32)
self.assertEqual(out3.dtype, paddle.float32)
class TestGradScaler(AmpTestBase):
def test_amp_grad_scaler(self):
model = paddle.nn.Conv2D(3, 2, 3)
optimizer = paddle.optimizer.SGD(
learning_rate=0.01, parameters=model.parameters()
)
scaler = paddle.amp.GradScaler()
data = paddle.rand([1, 3, 8, 8], dtype='float32')
paddle.amp.debugging.enable_operator_stats_collection()
with paddle.amp.auto_cast(
custom_black_list=['conv2d'], dtype='bfloat16'
):
out = model(data)
loss = out.mean()
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(optimizer, scaled)
optimizer.clear_grad()
paddle.amp.debugging.disable_operator_stats_collection()
op_list = paddle.fluid.core.get_low_precision_op_list()
self.assertEqual(scaler._enable, False)
self.assertEqual(scaler._use_dynamic_loss_scaling, False)
self.assertTrue('scale' not in op_list)
self.assertTrue('check_finite_and_unscale' not in op_list)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册