From 2bf61284890b43c8f4e27643690a5e407c250879 Mon Sep 17 00:00:00 2001 From: Zhang Ting Date: Mon, 8 May 2023 11:04:25 +0800 Subject: [PATCH] [AMP] fix static promote (#53439) --- python/paddle/static/amp/fp16_lists.py | 68 +++++++++---------- python/paddle/static/amp/fp16_utils.py | 20 ++++-- test/amp/test_amp_list.py | 7 +- test/amp/test_amp_promote.py | 1 - test/amp/test_model_cast_to_bf16.py | 8 ++- .../contrib/test_image_classification_fp16.py | 35 ++++++++-- 6 files changed, 87 insertions(+), 52 deletions(-) diff --git a/python/paddle/static/amp/fp16_lists.py b/python/paddle/static/amp/fp16_lists.py index 6e0a4a5254c..79abcc5e6c5 100644 --- a/python/paddle/static/amp/fp16_lists.py +++ b/python/paddle/static/amp/fp16_lists.py @@ -23,11 +23,15 @@ _logger = get_logger( ) # lookup_table fp16 is slower than fp32, though fp16 is supported. -_extra_unsupported_list = { +_extra_black_list = { 'lookup_table', 'lookup_table_v2', 'scatter', - 'scatter_grad', + 'linear_interp_v2', + 'nearest_interp_v2', + 'bilinear_interp_v2', + 'bicubic_interp_v2', + 'trilinear_interp_v2', } @@ -118,8 +122,7 @@ def _get_sys_unsupported_list(dtype): def _get_unsupported_list(dtype): # The set of ops that don't support fp16 calculation _, _sys_unsupported_list = _get_sys_unsupported_list(dtype) - unsupported_list = _extra_unsupported_list | _sys_unsupported_list - return unsupported_list + return _sys_unsupported_list # The three sets listed below are changed dynamiclly. They don't contain all @@ -145,6 +148,32 @@ def _get_white_list(dtype): return white_list_for_dtype +# The set of ops that support fp16 calculation and are considered numerically- +# dangerous and whose effects may also be observed in downstream ops. +black_list = { + 'exp', + 'square', + 'log', + 'mean', + 'sum', + 'cos_sim', + 'softmax', + 'softmax_with_cross_entropy', + 'sigmoid_cross_entropy_with_logits', + 'c_softmax_with_cross_entropy', + 'cross_entropy', + 'cross_entropy2', + # default fp32 can avoid return inf when the sum value large than 65504 + 'reduce_sum', +} + + +def _get_black_list(): + _black_list = copy.copy(black_list) + _black_list = _black_list | _extra_black_list + return _black_list + + class AutoMixedPrecisionLists: """ AutoMixedPrecisionLists is a class for black/white list. It can update @@ -170,7 +199,7 @@ class AutoMixedPrecisionLists: self._custom_white_list = custom_white_list self._custom_black_list = custom_black_list self.white_list = copy.copy(_get_white_list(self.amp_dtype)) - self.black_list = copy.copy(black_list) + self.black_list = copy.copy(_get_black_list()) self.gray_list = copy.copy(gray_list) self.unsupported_list = copy.copy(_get_unsupported_list(self.amp_dtype)) self.black_varnames = copy.copy(custom_black_varnames) @@ -196,8 +225,6 @@ class AutoMixedPrecisionLists: elif op_name in self.gray_list: self.gray_list.remove(op_name) self.white_list.add(op_name) - if op_name in _extra_unsupported_list: - self.unsupported_list.remove(op_name) if self._custom_black_list: for op_name in self._custom_black_list: if op_name in self.white_list: @@ -217,33 +244,6 @@ class AutoMixedPrecisionLists: ) -# The set of ops that support fp16 calculation and are considered numerically- -# dangerous and whose effects may also be observed in downstream ops. 
-black_list = {
-    'exp',
-    'square',
-    'log',
-    'mean',
-    'sum',
-    'cos_sim',
-    'softmax',
-    'softmax_with_cross_entropy',
-    'sigmoid_cross_entropy_with_logits',
-    'c_softmax_with_cross_entropy',
-    'cross_entropy',
-    'cross_entropy2',
-    # fp16 is slower than fp32, though fp16 is supported.
-    'lookup_table',
-    'lookup_table_v2',
-    'linear_interp_v2',
-    'nearest_interp_v2',
-    'bilinear_interp_v2',
-    'bicubic_interp_v2',
-    'trilinear_interp_v2',
-    # default fp32 can avoid return inf when the sum value large than 65504
-    'reduce_sum',
-}
-
 # This set contains two types of ops. All ops supported fp16 calculation. One
 # of two types is considered numerically-safe, but may be made unsafe by an
 # upstream blacklist op. Another type do not have numerically-significant
diff --git a/python/paddle/static/amp/fp16_utils.py b/python/paddle/static/amp/fp16_utils.py
index 740930769cb..c2f8a12d33b 100644
--- a/python/paddle/static/amp/fp16_utils.py
+++ b/python/paddle/static/amp/fp16_utils.py
@@ -425,15 +425,22 @@ def set_var_dst_dtype(


 def set_param_dtype(program, dtype, amp_lists, use_fp16_guard, level):
-    if level == "O1":
-        return
     keep_fp32_var_names = set()
+    if level == "O1":
+        return keep_fp32_var_names
     all_parameters = []
     for block in program.blocks:
         all_parameters.extend(block.all_parameters())
         ops = block.ops
         for op in ops:
+            # Currently, lookup_table is in black_list and unsupported_list, so its weight would be
+            # set to fp32 in step 1 of cast_model_to_fp16. But the weight may be used as matmul's
+            # input in transformer, so the weight should also be in to_fp16_var_names.
+            # TODO(zhangting2020): consider fixing auto_parallel_fp16 and removing lookup_table
+            # from black_list and unsupported_list.
+            if op.type in ['lookup_table', 'lookup_table_v2']:
+                continue
             if op_need_keep_fp32(op, amp_lists, use_fp16_guard):
                 for in_name in op.input_names:
                     keep_fp32_var_names = keep_fp32_var_names.union(
                         op.input(in_name)
@@ -451,6 +458,7 @@ def set_param_dtype(program, dtype, amp_lists, use_fp16_guard, level):
         if param.name not in keep_fp32_var_names:
             _logger.debug(f"-- set param {param.name} to {dtype} --.")
             param.desc.set_dtype(dtype)
+    return keep_fp32_var_names


 def op_need_keep_fp32(op, amp_lists, use_fp16_guard):
@@ -607,15 +615,17 @@ def cast_model_to_fp16(
     keep_fp32_ops = set()
     keep_fp16_ops = set()
     to_fp16_var_names = set()
+    keep_fp32_var_names = set()

     # step 1: set params dtype.
-        set_param_dtype(
+        fp32_var_names = set_param_dtype(
             program,
             dtype=dest_type,
             amp_lists=amp_lists,
             use_fp16_guard=use_fp16_guard,
             level=level,
         )
+        keep_fp32_var_names = keep_fp32_var_names.union(fp32_var_names)

     def need_process(op):
         need_process = True
@@ -719,6 +729,8 @@ def cast_model_to_fp16(
             idx += num_cast_ops + 1
     _logger.debug("---- after cast model to fp16 ----")
     _logger.debug(program)
+
+    to_fp16_var_names.difference_update(keep_fp32_var_names)
     return to_fp16_var_names


diff --git a/test/amp/test_amp_list.py b/test/amp/test_amp_list.py
index 9b0bf5129c3..e61aa8281ec 100644
--- a/test/amp/test_amp_list.py
+++ b/test/amp/test_amp_list.py
@@ -50,9 +50,10 @@ class TestAMPList(unittest.TestCase):
         self.check_if_op_not_in_list(
             self.custom_white_list, amp_list.black_list
         )
-        self.check_if_op_not_in_list(
-            self.custom_white_list, amp_list.unsupported_list
-        )
+        if paddle.amp.is_float16_supported():
+            self.check_if_op_not_in_list(
+                self.custom_white_list, amp_list.unsupported_list
+            )

     def test_eager(self):
         if not paddle.amp.is_float16_supported():
diff --git a/test/amp/test_amp_promote.py b/test/amp/test_amp_promote.py
index 8aa5d71a79e..e75799b39c3 100644
--- a/test/amp/test_amp_promote.py
+++ b/test/amp/test_amp_promote.py
@@ -48,7 +48,6 @@ class TestAMPPromote(AmpTestBase):
         max_iters = 2
         x_fp32 = np.random.random(size=[1, 1, 6, 6]).astype("float32")

-        print(main_program)
         losses_o1 = self.run_program(
             main_program,
             startup_program,
diff --git a/test/amp/test_model_cast_to_bf16.py b/test/amp/test_model_cast_to_bf16.py
index 3002b623b18..7e4de2630d4 100644
--- a/test/amp/test_model_cast_to_bf16.py
+++ b/test/amp/test_model_cast_to_bf16.py
@@ -265,18 +265,20 @@ class TestProgramBF16(AmpTestBase):
         amp.debugging.collect_operator_stats(main_program)
         op_stats_list = amp.debugging._get_op_stats_list(main_program)

+        expected_fp32_calls = {"lookup_table_v2": 1}
         expected_bf16_calls = {
             "matmul_v2": 1,
             "elementwise_add": 1,
             "dropout": 1,
             "lookup_table_v2": 0,
-            "squared_l2_norm": 2,
-            "adamw": 2,
+            "squared_l2_norm": 3,
+            "adamw": 3,
         }
         self._check_optimizer(
             main_program,
             expected_bf16_calls["matmul_v2"]
-            + expected_bf16_calls["elementwise_add"],
+            + expected_bf16_calls["elementwise_add"]
+            + expected_fp32_calls["lookup_table_v2"],
         )
         self._check_op_calls(op_stats_list[0], expected_bf16_calls)

diff --git a/test/contrib/test_image_classification_fp16.py b/test/contrib/test_image_classification_fp16.py
index fb1bafdc861..0fc98c4792d 100644
--- a/test/contrib/test_image_classification_fp16.py
+++ b/test/contrib/test_image_classification_fp16.py
@@ -318,7 +318,10 @@ class TestImageClassification(unittest.TestCase):
             copy.copy(paddle.static.amp.fp16_lists.white_list)
             | paddle.static.amp.fp16_lists._only_supported_fp16_list
         )
-        black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
+        black_list = copy.copy(
+            paddle.static.amp.fp16_lists.black_list
+            | paddle.static.amp.fp16_lists._extra_black_list
+        )
         gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)

         amp_lists = paddle.static.amp.AutoMixedPrecisionLists()
@@ -331,7 +334,10 @@ class TestImageClassification(unittest.TestCase):
             copy.copy(paddle.static.amp.fp16_lists.white_list)
             | paddle.static.amp.fp16_lists._only_supported_fp16_list
         )
-        black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
+        black_list = copy.copy(
+            paddle.static.amp.fp16_lists.black_list
+            | paddle.static.amp.fp16_lists._extra_black_list
+        )
         gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)

         # 1. 
w={'exp}, b=None @@ -348,7 +354,10 @@ class TestImageClassification(unittest.TestCase): copy.copy(paddle.static.amp.fp16_lists.white_list) | paddle.static.amp.fp16_lists._only_supported_fp16_list ) - black_list = copy.copy(paddle.static.amp.fp16_lists.black_list) + black_list = copy.copy( + paddle.static.amp.fp16_lists.black_list + | paddle.static.amp.fp16_lists._extra_black_list + ) gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list) # 2. w={'tanh'}, b=None @@ -365,7 +374,10 @@ class TestImageClassification(unittest.TestCase): copy.copy(paddle.static.amp.fp16_lists.white_list) | paddle.static.amp.fp16_lists._only_supported_fp16_list ) - black_list = copy.copy(paddle.static.amp.fp16_lists.black_list) + black_list = copy.copy( + paddle.static.amp.fp16_lists.black_list + | paddle.static.amp.fp16_lists._extra_black_list + ) gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list) # 3. w={'lstm'}, b=None @@ -381,7 +393,10 @@ class TestImageClassification(unittest.TestCase): copy.copy(paddle.static.amp.fp16_lists.white_list) | paddle.static.amp.fp16_lists._only_supported_fp16_list ) - black_list = copy.copy(paddle.static.amp.fp16_lists.black_list) + black_list = copy.copy( + paddle.static.amp.fp16_lists.black_list + | paddle.static.amp.fp16_lists._extra_black_list + ) gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list) # 4. w=None, b={'conv2d'} @@ -400,7 +415,10 @@ class TestImageClassification(unittest.TestCase): copy.copy(paddle.static.amp.fp16_lists.white_list) | paddle.static.amp.fp16_lists._only_supported_fp16_list ) - black_list = copy.copy(paddle.static.amp.fp16_lists.black_list) + black_list = copy.copy( + paddle.static.amp.fp16_lists.black_list + | paddle.static.amp.fp16_lists._extra_black_list + ) gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list) # 5. w=None, b={'tanh'} @@ -419,7 +437,10 @@ class TestImageClassification(unittest.TestCase): copy.copy(paddle.static.amp.fp16_lists.white_list) | paddle.static.amp.fp16_lists._only_supported_fp16_list ) - black_list = copy.copy(paddle.static.amp.fp16_lists.black_list) + black_list = copy.copy( + paddle.static.amp.fp16_lists.black_list + | paddle.static.amp.fp16_lists._extra_black_list + ) gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list) # 6. w=None, b={'lstm'} -- GitLab
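The patch above reorganizes the static AMP lists rather than just renaming them: the default black list is now composed by _get_black_list() as black_list | _extra_black_list, and promoting an op through custom_white_list removes it from the black list but no longer edits the device-dependent unsupported list. The snippet below is only an illustration, not part of the patch; it assumes a GPU build of Paddle that already contains this change and uses only names that appear in the diff (fp16_lists._get_black_list, fp16_lists._extra_black_list, AutoMixedPrecisionLists). It mirrors the expectations encoded in test_amp_list.py and test_image_classification_fp16.py.

# Illustrative sketch, assuming a Paddle build that includes this patch.
import paddle
from paddle.static.amp import fp16_lists

# The default black list is the union of the numerically-dangerous ops and
# _extra_black_list (ops such as lookup_table that are merely slow in fp16).
default_black = fp16_lists._get_black_list()
assert 'softmax' in default_black          # from black_list
assert 'lookup_table_v2' in default_black  # from _extra_black_list

# Promoting an op via custom_white_list moves it from the black list to the
# white list; the unsupported list is left to the system query alone, so its
# contents depend only on what the current device supports.
amp_list = paddle.static.amp.AutoMixedPrecisionLists(
    custom_white_list={'lookup_table_v2'}
)
assert 'lookup_table_v2' in amp_list.white_list
assert 'lookup_table_v2' not in amp_list.black_list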