diff --git a/python/paddle/amp/__init__.py b/python/paddle/amp/__init__.py
index 60df9de03ad11957a2df2c3c9506793835a7928a..5fa8055ba233b1b860f9e4d64639e415893af898 100644
--- a/python/paddle/amp/__init__.py
+++ b/python/paddle/amp/__init__.py
@@ -16,10 +16,8 @@ from .auto_cast import auto_cast  # noqa: F401
 from .auto_cast import decorate  # noqa: F401
 from .auto_cast import amp_guard  # noqa: F401
 from .auto_cast import amp_decorate  # noqa: F401
-from .auto_cast import FP16_WHITE_LIST  # noqa: F401
-from .auto_cast import FP16_BLACK_LIST  # noqa: F401
-from .auto_cast import PURE_FP16_WHITE_LIST  # noqa: F401
-from .auto_cast import PURE_FP16_BLACK_LIST  # noqa: F401
+from .amp_lists import white_list  # noqa: F401
+from .amp_lists import black_list  # noqa: F401
 
 from . import grad_scaler  # noqa: F401
 from .grad_scaler import GradScaler  # noqa: F401
diff --git a/python/paddle/amp/amp_lists.py b/python/paddle/amp/amp_lists.py
new file mode 100644
index 0000000000000000000000000000000000000000..f70c8f5ed7f913214a376c308a14eb934a35a8d7
--- /dev/null
+++ b/python/paddle/amp/amp_lists.py
@@ -0,0 +1,110 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# The set of ops that support fp16 calculation and are considered numerically-
+# safe and performance-critical. These ops are always converted to fp16.
+FP16_WHITE_LIST = {
+    'conv2d',
+    'matmul',
+    'matmul_v2',
+    'max_pool2d_with_index',
+    'mul',
+    'fake_quantize_dequantize_abs_max',
+    'fake_quantize_dequantize_moving_average_abs_max',
+}
+
+# The set of ops that support fp16 calculation and are considered numerically-
+# dangerous and whose effects may also be observed in downstream ops.
+FP16_BLACK_LIST = {
+    'tan',
+    'acos',
+    'asin',
+    'sinh',
+    'cosh',
+    'atanh',
+    'tanh_shrink',
+    'cos_sim',
+    'erfinv',
+    'exp',
+    'expm1',
+    'log',
+    'log10',
+    'log2',
+    'reciprocal',
+    'rsqrt',
+    'pow',
+    'square',
+    'reduce_sum',
+    'mean',
+    'reduce_mean',
+    'reduce_prod',
+    'cumprod',
+    'cumsum',
+    'dist',
+    'pnorm',
+    'frobenius_norm',
+    'renorm',
+    'group_norm',
+    'layer_norm',
+    'softmax',
+    'softmin',
+    'softplus',
+    'log_softmax',
+    'softmax_with_cross_entropy',
+    'sigmoid_cross_entropy_with_logits',
+    'c_softmax_with_cross_entropy',
+    'cross_entropy',
+    'cross_entropy2',
+    'nll_loss',
+    'huber_loss',
+    'triplet_margin_loss',
+    'log_loss',
+    'hsigmoid_loss',
+    'margin_cross_entropy',
+}
+
+# FP16 performance of grad op is worse than that of FP32. Use FP32 by default.
+FP16_EXTRA_BLACK_LIST = {
+    'linear_interp_v2',
+    'nearest_interp_v2',
+    'bilinear_interp_v2',
+    'bicubic_interp_v2',
+    'trilinear_interp_v2',
+    'lookup_table',
+    'lookup_table_v2',
+    'scatter',
+    'depthwise_conv2d',
+}
+
+BF16_WHITE_LIST = {'conv2d', 'matmul_v2'}
+BF16_BLACK_LIST = set()
+
+
+def white_list():
+    white_list = {
+        "float16": {"O1": FP16_WHITE_LIST, "O2": FP16_WHITE_LIST},
+        "bfloat16": {"O1": BF16_WHITE_LIST, "O2": BF16_WHITE_LIST},
+    }
+    return white_list
+
+
+def black_list():
+    black_list = {
+        "float16": {
+            "O1": FP16_BLACK_LIST | FP16_EXTRA_BLACK_LIST,
+            "O2": FP16_EXTRA_BLACK_LIST,
+        },
+        "bfloat16": {"O1": BF16_BLACK_LIST, "O2": set()},
+    }
+    return black_list
diff --git a/python/paddle/amp/auto_cast.py b/python/paddle/amp/auto_cast.py
index e8f552607affc73c0c379644d18867a12622842b..1f82533edbfb32eb1dc3f0e606a54de4f9f74d34 100644
--- a/python/paddle/amp/auto_cast.py
+++ b/python/paddle/amp/auto_cast.py
@@ -20,45 +20,7 @@ from paddle.fluid import core
 from paddle.fluid.framework import _dygraph_tracer, dygraph_only
 from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
 
-AMP_LEVEL = core.AmpLevel
-
-# The set of ops that support fp16 calculation and are considered numerically-
-# safe and performance-critical. These ops are always converted to fp16.
-FP16_WHITE_LIST = {
-    'conv2d',
-    'matmul',
-    'matmul_v2',
-    'max_pool2d_with_index',
-    'mul',
-    'fake_quantize_dequantize_abs_max',
-    'fake_quantize_dequantize_moving_average_abs_max',
-}
-
-# The set of ops that support fp16 calculation and are considered numerically-
-# dangerous and whose effects may also be observed in downstream ops.
-FP16_BLACK_LIST = {
-    'exp',
-    'square',
-    'log',
-    'mean',
-    'sum',
-    'cos_sim',
-    'softmax',
-    'softmax_with_cross_entropy',
-    'sigmoid_cross_entropy_with_logits',
-    'c_softmax_with_cross_entropy',
-    'cross_entropy',
-    'cross_entropy2',
-    # default fp32 can avoid return inf when the sum value large than 65504
-    'reduce_sum',
-    # FP16 performance of grad op is worse than that of FP32. Use FP32 by default.
-    'linear_interp_v2',
-    'nearest_interp_v2',
-    'bilinear_interp_v2',
-    'bicubic_interp_v2',
-    'trilinear_interp_v2',
-}
-
+from .amp_lists import black_list, white_list
 
 AMP_RELATED_FLAGS = [
     'FLAGS_cudnn_exhaustive_search',
@@ -72,27 +34,7 @@ AMP_RELATED_FLAGS_SETTING = {
     'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
 }
 
-PURE_FP16_WHITE_LIST = copy.copy(FP16_WHITE_LIST)
-
-PURE_FP16_BLACK_LIST = {
-    'lookup_table',
-    'lookup_table_v2',
-    'scatter',
-    'scatter_grad',
-    # FP16 performance of grad op is worse than that of FP32. Use FP32 by default.
-    'linear_interp_v2',
-    'nearest_interp_v2',
-    'bilinear_interp_v2',
-    'bicubic_interp_v2',
-    'trilinear_interp_v2',
-}
-
-BF16_WHITE_LIST = {'conv2d', 'matmul_v2'}
-BF16_BLACK_LIST = set()
-
-PURE_BF16_WHITE_LIST = copy.copy(BF16_WHITE_LIST)
-PURE_BF16_BLACK_LIST = set()
-
+AMP_LEVEL = core.AmpLevel
 
 _g_amp_state_ = None
 
@@ -126,20 +68,12 @@ def _update_list(
     """
     Update black and white list according to users' custom list.
""" - if dtype == 'float16': - if level == 'O1': - _white_list = copy.copy(FP16_WHITE_LIST) - _black_list = copy.copy(FP16_BLACK_LIST) - else: - _white_list = copy.copy(PURE_FP16_WHITE_LIST) - _black_list = copy.copy(PURE_FP16_BLACK_LIST) - else: - if level == 'O1': - _white_list = copy.copy(BF16_WHITE_LIST) - _black_list = copy.copy(BF16_BLACK_LIST) - else: - _white_list = copy.copy(PURE_BF16_WHITE_LIST) - _black_list = copy.copy(PURE_BF16_BLACK_LIST) + if level == 'O0': + _white_list = set() + _black_list = set() + return _white_list, _black_list + _white_list = copy.copy(white_list()[dtype][level]) + _black_list = copy.copy(black_list()[dtype][level]) if custom_white_list and custom_black_list: for op_name in custom_white_list: if op_name in custom_black_list: @@ -453,34 +387,14 @@ def amp_guard( if level == 'O1': amp_level = AMP_LEVEL.O1 - if dtype == 'float16': - _white_list = FP16_WHITE_LIST - _black_list = FP16_BLACK_LIST - elif dtype == 'bfloat16': - _white_list = BF16_WHITE_LIST - _black_list = BF16_BLACK_LIST - elif level == 'O2': amp_level = AMP_LEVEL.O2 - if dtype == 'float16': - _white_list = PURE_FP16_WHITE_LIST - _black_list = PURE_FP16_BLACK_LIST - elif dtype == 'bfloat16': - _white_list = BF16_WHITE_LIST - _black_list = BF16_BLACK_LIST elif level == 'O0': amp_level = AMP_LEVEL.O0 - if dtype == 'float16': - _white_list = FP16_WHITE_LIST - _black_list = FP16_BLACK_LIST - elif dtype == 'bfloat16': - _white_list = BF16_WHITE_LIST - _black_list = BF16_BLACK_LIST - - if custom_white_list or custom_black_list: - _white_list, _black_list = _update_list( - custom_white_list, custom_black_list, level, dtype - ) + + _white_list, _black_list = _update_list( + custom_white_list, custom_black_list, level, dtype + ) if not enable: amp_level = AMP_LEVEL.O0 diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py index 7a7d65d27d55642f84b70eeb57681fab5055b85a..8d24febaff213abf2bf1e28c27c3922278665f66 100644 --- a/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py @@ -88,8 +88,8 @@ class TestAutoCast(unittest.TestCase): def custom_op_list(self): with fluid.dygraph.guard(): tracer = fluid.framework._dygraph_tracer() - base_white_list = paddle.amp.FP16_WHITE_LIST - base_black_list = paddle.amp.FP16_BLACK_LIST + base_white_list = paddle.amp.white_list()["float16"]["O1"] + base_black_list = paddle.amp.black_list()["float16"]["O1"] with paddle.amp.amp_guard( custom_white_list=["log"], custom_black_list=["conv2d"] ): @@ -104,8 +104,8 @@ class TestAutoCast(unittest.TestCase): == (set(base_black_list) - {"log"}) | {"conv2d"} ) - base_white_list = paddle.amp.PURE_FP16_WHITE_LIST - base_black_list = paddle.amp.PURE_FP16_BLACK_LIST + base_white_list = paddle.amp.white_list()["float16"]["O2"] + base_black_list = paddle.amp.black_list()["float16"]["O2"] with paddle.amp.amp_guard( custom_white_list=["log"], custom_black_list=["conv2d"], diff --git a/test/amp/test_amp_list.py b/test/amp/test_amp_list.py index 11bcdbfd3ba6aaff4207c5797613453be0289183..9b0bf5129c36dac7f62f85fc6344a70e42e9ae47 100644 --- a/test/amp/test_amp_list.py +++ b/test/amp/test_amp_list.py @@ -14,32 +14,63 @@ import unittest +import paddle from paddle.fluid import core -from paddle.static.amp import fp16_lists -from 
+from paddle.static.amp import AutoMixedPrecisionLists, fp16_lists
 
 
 class TestAMPList(unittest.TestCase):
-    def test_main(self):
-        custom_white_list = [
-            'lookup_table',
-            'lookup_table_v2',
-        ]
-        amp_list = AutoMixedPrecisionLists(custom_white_list=custom_white_list)
-        for op in custom_white_list:
-            self.assertTrue(op in amp_list.white_list)
-            self.assertTrue(op not in amp_list.black_list)
-            self.assertTrue(op not in amp_list.unsupported_list)
-
-        default_black_list = [
+    def setUp(self):
+        self.default_black_list = [
             'linear_interp_v2',
             'nearest_interp_v2',
             'bilinear_interp_v2',
             'bicubic_interp_v2',
             'trilinear_interp_v2',
         ]
-        for op in default_black_list:
-            self.assertTrue(op in amp_list.black_list)
+        self.custom_white_list = [
+            'lookup_table',
+            'lookup_table_v2',
+        ]
+
+    def check_if_op_in_list(self, op_list, amp_list):
+        for op in op_list:
+            self.assertTrue(op in amp_list)
+
+    def check_if_op_not_in_list(self, op_list, amp_list):
+        for op in op_list:
+            self.assertTrue(op not in amp_list)
+
+    def test_static(self):
+        amp_list = AutoMixedPrecisionLists(
+            custom_white_list=self.custom_white_list
+        )
+        self.check_if_op_in_list(self.default_black_list, amp_list.black_list)
+        self.check_if_op_in_list(self.custom_white_list, amp_list.white_list)
+        self.check_if_op_not_in_list(
+            self.custom_white_list, amp_list.black_list
+        )
+        self.check_if_op_not_in_list(
+            self.custom_white_list, amp_list.unsupported_list
+        )
+
+    def test_eager(self):
+        if not paddle.amp.is_float16_supported():
+            return
+        white_list = paddle.amp.white_list()
+        black_list = paddle.amp.black_list()
+        self.check_if_op_in_list(
+            self.default_black_list, black_list["float16"]["O2"]
+        )
+        self.check_if_op_not_in_list(['log', 'elementwise_add'], white_list["float16"]["O1"])
+        with paddle.amp.auto_cast(custom_white_list={'elementwise_add'}):
+            out1 = paddle.rand([2, 3]) + paddle.rand([2, 3])
+            out2 = out1.mean()
+            out3 = paddle.log(out2)
+        self.check_if_op_not_in_list(['log', 'elementwise_add'], white_list["float16"]["O1"])
+        self.assertEqual(out1.dtype, paddle.float16)
+        self.assertEqual(out2.dtype, paddle.float32)
+        self.assertEqual(out3.dtype, paddle.float32)
 
     def test_apis(self):
         def _run_check_dtype():
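Usage note, not part of the patch above: a minimal sketch of how the relocated lists are intended to be consumed after this refactor. It assumes a Paddle build and device where float16 AMP is actually supported; the tensor shapes and the elementwise_add example are illustrative only.

import paddle

# The op sets now live in python/paddle/amp/amp_lists.py, grouped per dtype and
# per AMP level, and are read through accessor functions instead of the old
# module-level constants (FP16_WHITE_LIST, PURE_FP16_BLACK_LIST, ...).
fp16_o1_white = paddle.amp.white_list()["float16"]["O1"]
fp16_o2_black = paddle.amp.black_list()["float16"]["O2"]
assert "conv2d" in fp16_o1_white
assert "lookup_table" in fp16_o2_black

# custom_white_list still only affects the current guard; _update_list copies
# the sets returned by the accessors, so the module-level sets stay untouched.
with paddle.amp.auto_cast(custom_white_list={"elementwise_add"}):
    x = paddle.rand([2, 3])
    y = x + x          # runs in float16 on a CUDA device with fp16 support
    z = paddle.log(y)  # 'log' stays black-listed, so it runs in float32

Returning the sets through white_list() and black_list() keeps the O1/O2 and float16/bfloat16 variants in one place, which is what the reworked _update_list in auto_cast.py relies on.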