Unverified commit 41e90283, authored by Zhang Ting, committed by GitHub

[AMP]expand blacklists for amp training (#50940)

Parent: 5e1ee106
@@ -16,10 +16,8 @@ from .auto_cast import auto_cast # noqa: F401
from .auto_cast import decorate # noqa: F401
from .auto_cast import amp_guard # noqa: F401
from .auto_cast import amp_decorate # noqa: F401
from .auto_cast import FP16_WHITE_LIST # noqa: F401
from .auto_cast import FP16_BLACK_LIST # noqa: F401
from .auto_cast import PURE_FP16_WHITE_LIST # noqa: F401
from .auto_cast import PURE_FP16_BLACK_LIST # noqa: F401
from .amp_lists import white_list # noqa: F401
from .amp_lists import black_list # noqa: F401
from . import grad_scaler # noqa: F401
from .grad_scaler import GradScaler # noqa: F401
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The set of ops that support fp16 calculation and are considered numerically-
# safe and performance-critical. These ops are always converted to fp16.
FP16_WHITE_LIST = {
'conv2d',
'matmul',
'matmul_v2',
'max_pool2d_with_index',
'mul',
'fake_quantize_dequantize_abs_max',
'fake_quantize_dequantize_moving_average_abs_max',
}
# The set of ops that support fp16 calculation and are considered numerically-
# dangerous and whose effects may also be observed in downstream ops.
FP16_BLACK_LIST = {
'tan',
'acos',
'asin',
'sinh',
'cosh',
'atanh',
'tanh_shrink',
'cos_sim',
'erfinv',
'exp',
'expm1',
'log',
'log10',
'log2',
'reciprocal',
'rsqrt',
'pow',
'square',
'reduce_sum',
'mean',
'reduce_mean',
'reduce_prod',
'cumprod',
'cumsum',
'dist',
'pnorm',
'frobenius_norm',
'renorm',
'group_norm',
'layer_norm',
'softmax',
'softmin',
'softplus',
'log_softmax',
'softmax_with_cross_entropy',
'sigmoid_cross_entropy_with_logits',
'c_softmax_with_cross_entropy',
'cross_entropy',
'cross_entropy2',
'nll_loss',
'huber_loss',
'triplet_margin_loss',
'log_loss',
'hsigmoid_loss',
'margin_cross_entropy',
}
# FP16 performance of grad op is worse than that of FP32. Use FP32 by default.
FP16_EXTRA_BLACK_LIST = {
'linear_interp_v2',
'nearest_interp_v2',
'bilinear_interp_v2',
'bicubic_interp_v2',
'trilinear_interp_v2',
'lookup_table',
'lookup_table_v2',
'scatter',
'depthwise_conv2d',
}
BF16_WHITE_LIST = {'conv2d', 'matmul_v2'}
BF16_BLACK_LIST = set()
def white_list():
white_list = {
"float16": {"O1": FP16_WHITE_LIST, "O2": FP16_WHITE_LIST},
"bfloat16": {"O1": BF16_WHITE_LIST, "O2": BF16_WHITE_LIST},
}
return white_list
def black_list():
black_list = {
"float16": {
"O1": FP16_BLACK_LIST | FP16_EXTRA_BLACK_LIST,
"O2": FP16_EXTRA_BLACK_LIST,
},
"bfloat16": {"O1": BF16_BLACK_LIST, "O2": set()},
}
return black_list
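Downstream code (including the updated tests further below) reads these tables through the paddle.amp.white_list / paddle.amp.black_list exports added in __init__.py above. A minimal query sketch, not part of this diff, showing how the two levels differ:

    import paddle

    # O1 applies the full black list; O2 keeps only the grad-performance entries.
    o1_black = paddle.amp.black_list()["float16"]["O1"]  # FP16_BLACK_LIST | FP16_EXTRA_BLACK_LIST
    o2_black = paddle.amp.black_list()["float16"]["O2"]  # FP16_EXTRA_BLACK_LIST only

    assert "conv2d" in paddle.amp.white_list()["float16"]["O1"]
    assert "softmax" in o1_black and "softmax" not in o2_black
    assert "lookup_table_v2" in o1_black and "lookup_table_v2" in o2_black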
@@ -20,45 +20,7 @@ from paddle.fluid import core
from paddle.fluid.framework import _dygraph_tracer, dygraph_only
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
AMP_LEVEL = core.AmpLevel
# The set of ops that support fp16 calculation and are considered numerically-
# safe and performance-critical. These ops are always converted to fp16.
FP16_WHITE_LIST = {
'conv2d',
'matmul',
'matmul_v2',
'max_pool2d_with_index',
'mul',
'fake_quantize_dequantize_abs_max',
'fake_quantize_dequantize_moving_average_abs_max',
}
# The set of ops that support fp16 calculation and are considered numerically-
# dangerous and whose effects may also be observed in downstream ops.
FP16_BLACK_LIST = {
'exp',
'square',
'log',
'mean',
'sum',
'cos_sim',
'softmax',
'softmax_with_cross_entropy',
'sigmoid_cross_entropy_with_logits',
'c_softmax_with_cross_entropy',
'cross_entropy',
'cross_entropy2',
# Defaulting to fp32 avoids returning inf when the sum exceeds 65504 (the fp16 max).
'reduce_sum',
# FP16 performance of grad op is worse than that of FP32. Use FP32 by default.
'linear_interp_v2',
'nearest_interp_v2',
'bilinear_interp_v2',
'bicubic_interp_v2',
'trilinear_interp_v2',
}
from .amp_lists import black_list, white_list
AMP_RELATED_FLAGS = [
'FLAGS_cudnn_exhaustive_search',
@@ -72,27 +34,7 @@ AMP_RELATED_FLAGS_SETTING = {
'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
}
PURE_FP16_WHITE_LIST = copy.copy(FP16_WHITE_LIST)
PURE_FP16_BLACK_LIST = {
'lookup_table',
'lookup_table_v2',
'scatter',
'scatter_grad',
# FP16 performance of grad op is worse than that of FP32. Use FP32 by default.
'linear_interp_v2',
'nearest_interp_v2',
'bilinear_interp_v2',
'bicubic_interp_v2',
'trilinear_interp_v2',
}
BF16_WHITE_LIST = {'conv2d', 'matmul_v2'}
BF16_BLACK_LIST = set()
PURE_BF16_WHITE_LIST = copy.copy(BF16_WHITE_LIST)
PURE_BF16_BLACK_LIST = set()
AMP_LEVEL = core.AmpLevel
_g_amp_state_ = None
@@ -126,20 +68,12 @@ def _update_list(
"""
Update black and white list according to users' custom list.
"""
if dtype == 'float16':
if level == 'O1':
_white_list = copy.copy(FP16_WHITE_LIST)
_black_list = copy.copy(FP16_BLACK_LIST)
else:
_white_list = copy.copy(PURE_FP16_WHITE_LIST)
_black_list = copy.copy(PURE_FP16_BLACK_LIST)
else:
if level == 'O1':
_white_list = copy.copy(BF16_WHITE_LIST)
_black_list = copy.copy(BF16_BLACK_LIST)
else:
_white_list = copy.copy(PURE_BF16_WHITE_LIST)
_black_list = copy.copy(PURE_BF16_BLACK_LIST)
if level == 'O0':
_white_list = set()
_black_list = set()
return _white_list, _black_list
_white_list = copy.copy(white_list()[dtype][level])
_black_list = copy.copy(black_list()[dtype][level])
if custom_white_list and custom_black_list:
for op_name in custom_white_list:
if op_name in custom_black_list:
@@ -453,34 +387,14 @@ def amp_guard(
if level == 'O1':
amp_level = AMP_LEVEL.O1
if dtype == 'float16':
_white_list = FP16_WHITE_LIST
_black_list = FP16_BLACK_LIST
elif dtype == 'bfloat16':
_white_list = BF16_WHITE_LIST
_black_list = BF16_BLACK_LIST
elif level == 'O2':
amp_level = AMP_LEVEL.O2
if dtype == 'float16':
_white_list = PURE_FP16_WHITE_LIST
_black_list = PURE_FP16_BLACK_LIST
elif dtype == 'bfloat16':
_white_list = BF16_WHITE_LIST
_black_list = BF16_BLACK_LIST
elif level == 'O0':
amp_level = AMP_LEVEL.O0
if dtype == 'float16':
_white_list = FP16_WHITE_LIST
_black_list = FP16_BLACK_LIST
elif dtype == 'bfloat16':
_white_list = BF16_WHITE_LIST
_black_list = BF16_BLACK_LIST
if custom_white_list or custom_black_list:
_white_list, _black_list = _update_list(
custom_white_list, custom_black_list, level, dtype
)
_white_list, _black_list = _update_list(
custom_white_list, custom_black_list, level, dtype
)
if not enable:
amp_level = AMP_LEVEL.O0
......
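The "65504" comment removed from the old FP16_BLACK_LIST above captures the rationale for black-listing reductions such as reduce_sum and mean: fp16 tops out at 65504, so an fp16 accumulation can overflow to inf where an fp32 one stays finite. A small NumPy sketch (illustration only, not Paddle code):

    import numpy as np

    # Accumulate in fp16: the running sum overflows the fp16 maximum (65504).
    acc = np.float16(0.0)
    for _ in range(100):
        acc = np.float16(acc + np.float16(1000.0))
    print(acc)    # inf -- the true sum (100000.0) is not representable in fp16

    # The same accumulation in fp32 stays finite.
    acc32 = np.float32(0.0)
    for _ in range(100):
        acc32 += np.float32(1000.0)
    print(acc32)  # 100000.0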
@@ -88,8 +88,8 @@ class TestAutoCast(unittest.TestCase):
def custom_op_list(self):
with fluid.dygraph.guard():
tracer = fluid.framework._dygraph_tracer()
base_white_list = paddle.amp.FP16_WHITE_LIST
base_black_list = paddle.amp.FP16_BLACK_LIST
base_white_list = paddle.amp.white_list()["float16"]["O1"]
base_black_list = paddle.amp.black_list()["float16"]["O1"]
with paddle.amp.amp_guard(
custom_white_list=["log"], custom_black_list=["conv2d"]
):
@@ -104,8 +104,8 @@ class TestAutoCast(unittest.TestCase):
== (set(base_black_list) - {"log"}) | {"conv2d"}
)
base_white_list = paddle.amp.PURE_FP16_WHITE_LIST
base_black_list = paddle.amp.PURE_FP16_BLACK_LIST
base_white_list = paddle.amp.white_list()["float16"]["O2"]
base_black_list = paddle.amp.black_list()["float16"]["O2"]
with paddle.amp.amp_guard(
custom_white_list=["log"],
custom_black_list=["conv2d"],
......
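The assertions above pin down how custom lists interact with the base lists: a custom white-listed op leaves the black list, and a custom black-listed op joins it. A hypothetical helper (merge_custom is not a Paddle API) that mirrors the same set arithmetic, assuming _update_list behaves as the test asserts:

    import paddle

    def merge_custom(base_white, base_black, custom_white=(), custom_black=()):
        # Custom white ops join the white list and leave the black list;
        # custom black ops do the opposite.
        white = (set(base_white) | set(custom_white)) - set(custom_black)
        black = (set(base_black) | set(custom_black)) - set(custom_white)
        return white, black

    base_white = paddle.amp.white_list()["float16"]["O1"]
    base_black = paddle.amp.black_list()["float16"]["O1"]
    white, black = merge_custom(base_white, base_black, {"log"}, {"conv2d"})
    assert black == (set(base_black) - {"log"}) | {"conv2d"}  # matches the check above
    assert white == (set(base_white) - {"conv2d"}) | {"log"}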
@@ -14,32 +14,63 @@
import unittest
import paddle
from paddle.fluid import core
from paddle.static.amp import fp16_lists
from paddle.static.amp.fp16_lists import AutoMixedPrecisionLists
from paddle.static.amp import AutoMixedPrecisionLists, fp16_lists
class TestAMPList(unittest.TestCase):
def test_main(self):
custom_white_list = [
'lookup_table',
'lookup_table_v2',
]
amp_list = AutoMixedPrecisionLists(custom_white_list=custom_white_list)
for op in custom_white_list:
self.assertTrue(op in amp_list.white_list)
self.assertTrue(op not in amp_list.black_list)
self.assertTrue(op not in amp_list.unsupported_list)
default_black_list = [
def setUp(self):
self.default_black_list = [
'linear_interp_v2',
'nearest_interp_v2',
'bilinear_interp_v2',
'bicubic_interp_v2',
'trilinear_interp_v2',
]
for op in default_black_list:
self.assertTrue(op in amp_list.black_list)
self.custom_white_list = [
'lookup_table',
'lookup_table_v2',
]
def check_if_op_in_list(self, op_list, amp_list):
for op in op_list:
self.assertTrue(op in amp_list)
def check_if_op_not_in_list(self, op_list, amp_list):
for op in op_list:
self.assertTrue(op not in amp_list)
def test_static(self):
amp_list = AutoMixedPrecisionLists(
custom_white_list=self.custom_white_list
)
self.check_if_op_in_list(self.default_black_list, amp_list.black_list)
self.check_if_op_in_list(self.custom_white_list, amp_list.white_list)
self.check_if_op_not_in_list(
self.custom_white_list, amp_list.black_list
)
self.check_if_op_not_in_list(
self.custom_white_list, amp_list.unsupported_list
)
def test_eager(self):
if not paddle.amp.is_float16_supported():
return
white_list = paddle.amp.white_list()
black_list = paddle.amp.black_list()
self.check_if_op_in_list(
self.default_black_list, black_list["float16"]["O2"]
)
self.check_if_op_not_in_list(['log', 'elementwise_add'], white_list)
with paddle.amp.auto_cast(custom_white_list={'elementwise_add'}):
out1 = paddle.rand([2, 3]) + paddle.rand([2, 3])
out2 = out1.mean()
out3 = paddle.log(out2)
self.check_if_op_not_in_list(['log', 'elementwise_add'], white_list)
self.assertEqual(out1.dtype, paddle.float16)
self.assertEqual(out2.dtype, paddle.float32)
self.assertEqual(out3.dtype, paddle.float32)
def test_apis(self):
def _run_check_dtype():
......
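For completeness, a hedged end-to-end sketch of how the split lists play out in O2 dygraph training with GradScaler (assumes a CUDA device with fp16 support; the Linear model and tensor sizes are arbitrary placeholders):

    import paddle

    if paddle.amp.is_float16_supported():
        model = paddle.nn.Linear(16, 16)
        opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
        model, opt = paddle.amp.decorate(models=model, optimizers=opt, level='O2')
        scaler = paddle.amp.GradScaler()

        data = paddle.rand([4, 16])
        with paddle.amp.auto_cast(level='O2'):
            # matmul (white-listed) runs in fp16; ops in FP16_EXTRA_BLACK_LIST,
            # e.g. lookup_table_v2, would still run in fp32 even at O2.
            loss = model(data).mean()
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        opt.clear_grad()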