From cf64aa0ba8d222a2ce4c85cb133be0ed17310523 Mon Sep 17 00:00:00 2001 From: niuliling123 <51102941+niuliling123@users.noreply.github.com> Date: Wed, 14 Jun 2023 16:22:10 +0800 Subject: [PATCH] Change Static AMP List (#54135) (#54591) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit AMP 动静态图黑名单统一 --- python/paddle/static/amp/fp16_lists.py | 53 ++++++-------------------- 1 file changed, 11 insertions(+), 42 deletions(-) diff --git a/python/paddle/static/amp/fp16_lists.py b/python/paddle/static/amp/fp16_lists.py index 96ad079879a..b057f1adf21 100644 --- a/python/paddle/static/amp/fp16_lists.py +++ b/python/paddle/static/amp/fp16_lists.py @@ -15,6 +15,11 @@ import copy import logging +from paddle.amp.amp_lists import ( + FP16_BLACK_LIST, + FP16_EXTRA_BLACK_LIST, + FP16_WHITE_LIST, +) from paddle.fluid import core from paddle.fluid.log_helper import get_logger @@ -22,17 +27,9 @@ _logger = get_logger( __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s' ) -# lookup_table fp16 is slower than fp32, though fp16 is supported. -_extra_black_list = { - 'lookup_table', - 'lookup_table_v2', - 'scatter', - 'linear_interp_v2', - 'nearest_interp_v2', - 'bilinear_interp_v2', - 'bicubic_interp_v2', - 'trilinear_interp_v2', -} +black_list = FP16_BLACK_LIST +_extra_black_list = FP16_EXTRA_BLACK_LIST +white_list = FP16_WHITE_LIST def check_amp_dtype(dtype): @@ -131,45 +128,17 @@ def _get_unsupported_list(dtype): _only_supported_fp16_list = {'resnet_unit', 'fused_bn_add_activation'} -white_list = { - 'conv2d', - 'einsum', - 'matmul', - 'matmul_v2', - 'mul', -} - def _get_white_list(dtype): - white_list_for_dtype = copy.copy(white_list) + white_list_for_dtype = copy.copy(FP16_WHITE_LIST) if dtype == 'float16': white_list_for_dtype = white_list_for_dtype | _only_supported_fp16_list return white_list_for_dtype -# The set of ops that support fp16 calculation and are considered numerically- -# dangerous and whose effects may also be observed in downstream ops. -black_list = { - 'exp', - 'square', - 'log', - 'mean', - 'sum', - 'cos_sim', - 'softmax', - 'softmax_with_cross_entropy', - 'sigmoid_cross_entropy_with_logits', - 'c_softmax_with_cross_entropy', - 'cross_entropy', - 'cross_entropy2', - # default fp32 can avoid return inf when the sum value large than 65504 - 'reduce_sum', -} - - def _get_black_list(): - _black_list = copy.copy(black_list) - _black_list = _black_list | _extra_black_list + _black_list = copy.copy(FP16_BLACK_LIST) + _black_list = _black_list | FP16_EXTRA_BLACK_LIST return _black_list -- GitLab