Unverified commit 18e9dcdc, authored by Zhang Ting, committed by GitHub

[AMP] support OD level and skip dynamic loss scaling for bf16 (#53289)

* support OD level and skip dynamic loss scaling for bf16
Parent 3278dec7
@@ -129,7 +129,11 @@ inline phi::DataType GetAmpDestDtype(
                  ->count(op_name)) {
       dst_type = phi::DataType::FLOAT32;
     } else {
-      dst_type = GetPromoteType(op_name, amp_tensors_vector, amp_setting_dtype);
+      if (amp_level == paddle::imperative::AmpLevel::OD) {
+        dst_type = phi::DataType::FLOAT32;
+      } else {
+        dst_type = GetPromoteType(op_name, amp_tensors_vector, amp_setting_dtype);
+      }
     }
     if (dst_type == amp_setting_dtype &&
......
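For orientation, here is a minimal standalone sketch of the dtype decision this hunk implements: at the new OD level, any op that is not on the white list falls back to float32 instead of going through the promotion rules. The function name and default lists below are illustrative placeholders, not Paddle's real internals.

```python
# Illustrative sketch of the GetAmpDestDtype logic above; names and the
# default lists are placeholders, not Paddle's actual internals.
def get_amp_dest_dtype(
    op_name,
    amp_level,
    amp_setting_dtype,
    white_list=frozenset({'conv2d', 'matmul_v2'}),
    black_list=frozenset(),
):
    if op_name in white_list:
        return amp_setting_dtype  # white-listed ops compute in fp16/bf16
    if op_name in black_list:
        return 'float32'  # black-listed ops always compute in fp32
    if amp_level == 'OD':
        return 'float32'  # OD: everything outside the white list stays fp32
    return amp_setting_dtype  # O1/O2: simplified stand-in for the promote rules


assert get_amp_dest_dtype('matmul_v2', 'OD', 'float16') == 'float16'
assert get_amp_dest_dtype('elementwise_add', 'OD', 'float16') == 'float32'
```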
@@ -31,6 +31,7 @@ enum class AmpLevel {
   O1,  // amp, mixed fp32-fp16
   O2,  // almost fp16
   O3,  // fp16
+  OD,  // only conv and matmul use low precision.
 };

 std::tuple<std::unordered_set<std::string>,
......
@@ -2152,6 +2152,7 @@ void BindImperative(py::module *m_ptr) {
   py::enum_<paddle::imperative::AmpLevel>(m, "AmpLevel", py::arithmetic())
       .value("O0", paddle::imperative::AmpLevel::O0)
+      .value("OD", paddle::imperative::AmpLevel::OD)
       .value("O1", paddle::imperative::AmpLevel::O1)
       .value("O2", paddle::imperative::AmpLevel::O2)
       .value("O3", paddle::imperative::AmpLevel::O3)
......
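A quick way to confirm the new enum value is visible from Python. This assumes a Paddle build containing this patch and that the enum is still exported through `paddle.fluid.core`, as the AMP Python code of this era expects.

```python
from paddle.fluid import core

# AmpLevel.OD should now appear alongside the existing levels.
print(core.AmpLevel.OD, core.AmpLevel.O0, core.AmpLevel.O1, core.AmpLevel.O2)
```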
@@ -91,10 +91,19 @@ BF16_WHITE_LIST = {'conv2d', 'matmul_v2'}
 BF16_BLACK_LIST = set()


+# At OD level, ops in WHITE_LIST will use FP16/BF16 and the others will use FP32.
 def white_list():
     white_list = {
-        "float16": {"O1": FP16_WHITE_LIST, "O2": FP16_WHITE_LIST},
-        "bfloat16": {"O1": BF16_WHITE_LIST, "O2": BF16_WHITE_LIST},
+        "float16": {
+            "OD": FP16_WHITE_LIST,
+            "O1": FP16_WHITE_LIST,
+            "O2": FP16_WHITE_LIST,
+        },
+        "bfloat16": {
+            "OD": BF16_WHITE_LIST,
+            "O1": BF16_WHITE_LIST,
+            "O2": BF16_WHITE_LIST,
+        },
     }
     return white_list

@@ -102,9 +111,10 @@ def white_list():
 def black_list():
     black_list = {
         "float16": {
+            "OD": set(),
             "O1": FP16_BLACK_LIST | FP16_EXTRA_BLACK_LIST,
             "O2": FP16_EXTRA_BLACK_LIST,
         },
-        "bfloat16": {"O1": BF16_BLACK_LIST, "O2": set()},
+        "bfloat16": {"OD": set(), "O1": BF16_BLACK_LIST, "O2": set()},
     }
     return black_list
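To make the per-level tables concrete, here is a standalone sketch of the lookup they enable. Except for BF16_WHITE_LIST, which mirrors the value shown in the diff context, the op sets are placeholders, and `lists_for` is an illustrative helper, not a Paddle API.

```python
# Placeholder op sets; only BF16_WHITE_LIST mirrors the value shown in the diff.
FP16_WHITE_LIST = {'conv2d', 'matmul_v2'}
FP16_BLACK_LIST = {'exp'}
FP16_EXTRA_BLACK_LIST = {'lookup_table'}
BF16_WHITE_LIST = {'conv2d', 'matmul_v2'}
BF16_BLACK_LIST = set()

WHITE = {
    "float16": {"OD": FP16_WHITE_LIST, "O1": FP16_WHITE_LIST, "O2": FP16_WHITE_LIST},
    "bfloat16": {"OD": BF16_WHITE_LIST, "O1": BF16_WHITE_LIST, "O2": BF16_WHITE_LIST},
}
BLACK = {
    "float16": {
        "OD": set(),
        "O1": FP16_BLACK_LIST | FP16_EXTRA_BLACK_LIST,
        "O2": FP16_EXTRA_BLACK_LIST,
    },
    "bfloat16": {"OD": set(), "O1": BF16_BLACK_LIST, "O2": set()},
}


def lists_for(dtype, level):
    """Return the (white, black) op sets for a given AMP dtype and level."""
    return WHITE[dtype][level], BLACK[dtype][level]


# At OD the black list is empty: only white-listed ops run in low precision,
# everything else is forced to fp32 by the C++ GetAmpDestDtype change above.
white, black = lists_for("bfloat16", "OD")
assert 'conv2d' in white and black == set()
```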
@@ -48,6 +48,7 @@ class AMPGlobalState:
         self.model_parameters = []
         self.use_master_grad = False
         self.already_register_final_backward_hook = False
+        self.amp_dtype = 'float32'

     def __setattr__(self, name, val):
         self.__dict__[name] = val
@@ -320,10 +321,8 @@ def amp_guard(
     # check amp_level: O0-O2
     level = level.upper()
-    if not (level in ['O0', 'O1', 'O2']):
-        raise ValueError(
-            "level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode."
-        )
+    if not (level in ['O0', 'OD', 'O1', 'O2']):
+        raise ValueError("level should be O0, OD, O1 or O2.")

     # check amp_dtype: float16 or bfloat16
     dtype = dtype.lower()

@@ -384,8 +383,11 @@ def amp_guard(
         )

     amp_dtype = dtype
+    amp_global_state().amp_dtype = amp_dtype

-    if level == 'O1':
+    if level == 'OD':
+        amp_level = AMP_LEVEL.OD
+    elif level == 'O1':
         amp_level = AMP_LEVEL.O1
     elif level == 'O2':
         amp_level = AMP_LEVEL.O2
@@ -642,22 +644,24 @@ def auto_cast(
 ):
     """
     Create a context which enables auto-mixed-precision(AMP) of operators executed in dynamic graph mode.
-    If enabled, the input data type (float32 or float16) of each operator is decided
+    If enabled, the input data type (float32, float16 or bfloat16) of each operator is decided
     by autocast algorithm for better performance.
-    Commonly, it is used together with `GradScaler` to achieve Auto-Mixed-Precision in
-    imperative mode. It is used together with `decorator` to achieve Pure fp16 in imperative mode.
+    Commonly, it is used together with `GradScaler` and `decorator` to achieve Auto-Mixed-Precision in
+    imperative mode.

     Args:
         enable(bool, optional): Enable auto-mixed-precision or not. Default is True.
-        custom_white_list(set|list|tuple, optional): The custom white_list. It's the set of ops that support
-            fp16 calculation and are considered numerically-safe and performance-critical. These ops
-            will be converted to fp16.
-        custom_black_list(set|list|tuple, optional): The custom black_list. The set of ops that support fp16
-            calculation and are considered numerically-dangerous and whose effects may also be
-            observed in downstream ops. These ops will not be converted to fp16.
-        level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision, the input data type of each operator will be casted by white_list and black_list;
-            O2 represent Pure fp16, all operators parameters and input data will be casted to fp16, except operators in black_list, don't support fp16 kernel and batchnorm. Default is O1(amp)
+        custom_white_list(set|list|tuple, optional): A default white list is already set; usually there is no need to provide a custom one.
+            The ops in this set should be numerically safe and performance-critical; they will be converted to float16/bfloat16.
+        custom_black_list(set|list|tuple, optional): A default black list is already set; you can provide a custom black list according to the model.
+            The ops in this set are considered numerically dangerous, and their effects may also be observed in downstream ops; they will not be
+            converted to float16/bfloat16.
+        level(str, optional): Auto mixed precision level. Accepted values are "O1", "O2" and "OD". At the O1 level, operators in the white list
+            use float16/bfloat16 inputs for calculations, and operators in the black list use float32 inputs. At the O2
+            level, the model's parameters are cast to float16/bfloat16 by using `decorator`; operators whose inputs are all float16/bfloat16
+            run in float16/bfloat16, and operators with any float32 input fall back to float32. At the OD level, operators in the
+            default white list compute in float16/bfloat16, and all others compute in float32. Default is O1.
         dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.

     Examples:
......
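A short usage sketch of the documented levels. It assumes a CUDA device with float16 support and a Paddle build that includes the new OD level; on unsupported hardware Paddle falls back to float32 with a warning.

```python
import paddle

model = paddle.nn.Linear(4, 4)
opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
data = paddle.rand([4, 4])

# OD: only default white-list ops (conv2d, matmul) compute in float16,
# every other op keeps float32 inputs.
with paddle.amp.auto_cast(level='OD'):
    loss = model(data).mean()

# O1: per-op dtype is chosen from the white/black lists and promotion rules,
# typically combined with GradScaler for dynamic loss scaling.
with paddle.amp.auto_cast(level='O1', dtype='float16'):
    loss = model(data).mean()
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(opt, scaled)
opt.clear_grad()

# O2: parameters themselves are cast to float16 via `decorate`; ops with any
# float32 input still fall back to float32.
model, opt = paddle.amp.decorate(models=model, optimizers=opt, level='O2')
with paddle.amp.auto_cast(level='O2', dtype='float16'):
    loss = model(data).mean()
```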
@@ -24,6 +24,8 @@ from paddle.fluid.data_feeder import check_type
 from paddle.fluid.dygraph import to_variable
 from paddle.fluid.framework import _dygraph_tracer, dygraph_only

+from .auto_cast import amp_global_state
+

 class OptimizerState(Enum):
     INIT = 0

@@ -179,6 +181,18 @@ class AmpScaler:
         """
         check_type(var, "var", core.eager.Tensor, 'AmpScaler.scale()')
+        if (
+            self._enable
+            and amp_global_state().amp_dtype != 'float16'
+            and self._use_dynamic_loss_scaling
+        ):
+            self._enable = False
+            self._use_dynamic_loss_scaling = False
+            warnings.warn(
+                'It is not recommended to use dynamic loss scaling for %s, so GradScaler is disabled by default.'
+                % (amp_global_state().amp_dtype)
+            )
+
         if not self._enable:
             return var
......
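In user terms, the effect of this hunk is that a GradScaler created with its defaults turns itself off the first time `scale()` is called under bfloat16. A small sketch (assuming a build with this patch and bfloat16-capable hardware); the tests below exercise the same path.

```python
import paddle

model = paddle.nn.Conv2D(3, 2, 3)
scaler = paddle.amp.GradScaler()  # dynamic loss scaling enabled by default

with paddle.amp.auto_cast(dtype='bfloat16'):
    loss = model(paddle.rand([1, 3, 8, 8])).mean()

scaled = scaler.scale(loss)  # warns, then disables scaling for bfloat16
print(scaler._enable)                    # False
print(scaler._use_dynamic_loss_scaling)  # False
```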
@@ -194,8 +194,11 @@ class TestAutoCast(unittest.TestCase):
 class TestAmpScaler(unittest.TestCase):
     def scale(self):
+        if not paddle.amp.is_float16_supported():
+            return
         with fluid.dygraph.guard():
-            data = paddle.rand([10, 1024])
+            with paddle.amp.auto_cast(dtype='float16'):
+                data = paddle.rand([10, 1024])
             scaler = paddle.amp.AmpScaler(init_loss_scaling=1024)
             scaled_data = scaler.scale(data)
             self.assertEqual(

@@ -333,9 +336,9 @@ class TestAmpScaler(unittest.TestCase):
             )
             scaler = paddle.amp.AmpScaler(init_loss_scaling=1024)
             data = fluid.dygraph.to_variable(inp_np)
-            out = model(data)
-            loss = paddle.mean(out)
+            with paddle.amp.auto_cast(dtype='float16'):
+                out = model(data)
+                loss = paddle.mean(out)
             scaled_loss = scaler.scale(loss)
             scaled_loss.backward()
             optimize_ops, params_grads = scaler.minimize(optimizer, scaled_loss)

@@ -348,6 +351,8 @@ class TestAmpScaler(unittest.TestCase):
         )

     def test_nan_inf(self):
+        if not paddle.amp.is_float16_supported():
+            return
         self.nan_inf()

     def step_update_exception(self):
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from amp_base_models import AmpTestBase

import paddle


class TestAutoCast(AmpTestBase):
    def test_amp_OD_level(self):
        conv = paddle.nn.Conv2D(
            in_channels=1, out_channels=6, kernel_size=3, bias_attr=False
        )
        linear = paddle.nn.Linear(in_features=4, out_features=4)
        with paddle.amp.auto_cast(level='OD'):
            out1 = conv(paddle.rand(shape=[1, 1, 6, 6], dtype='float32'))
            out2 = out1 + paddle.rand(shape=out1.shape, dtype='float16')
            out3 = linear(out2)

        self.assertEqual(out1.dtype, paddle.float16)
        self.assertEqual(out2.dtype, paddle.float32)
        self.assertEqual(out3.dtype, paddle.float32)


class TestGradScaler(AmpTestBase):
    def test_amp_grad_scaler(self):
        model = paddle.nn.Conv2D(3, 2, 3)
        optimizer = paddle.optimizer.SGD(
            learning_rate=0.01, parameters=model.parameters()
        )
        scaler = paddle.amp.GradScaler()
        data = paddle.rand([1, 3, 8, 8], dtype='float32')
        paddle.amp.debugging.enable_operator_stats_collection()
        with paddle.amp.auto_cast(
            custom_black_list=['conv2d'], dtype='bfloat16'
        ):
            out = model(data)
            loss = out.mean()
        scaled = scaler.scale(loss)
        scaled.backward()
        scaler.minimize(optimizer, scaled)
        optimizer.clear_grad()
        paddle.amp.debugging.disable_operator_stats_collection()
        op_list = paddle.fluid.core.get_low_precision_op_list()

        self.assertEqual(scaler._enable, False)
        self.assertEqual(scaler._use_dynamic_loss_scaling, False)
        self.assertTrue('scale' not in op_list)
        self.assertTrue('check_finite_and_unscale' not in op_list)


if __name__ == '__main__':
    unittest.main()