diff --git a/python/paddle/static/amp/fp16_lists.py b/python/paddle/static/amp/fp16_lists.py index 96ad079879a50113ae41798097f9f952a94e3213..b057f1adf215260e045e2b4c635487f4e018df82 100644 --- a/python/paddle/static/amp/fp16_lists.py +++ b/python/paddle/static/amp/fp16_lists.py @@ -15,6 +15,11 @@ import copy import logging +from paddle.amp.amp_lists import ( + FP16_BLACK_LIST, + FP16_EXTRA_BLACK_LIST, + FP16_WHITE_LIST, +) from paddle.fluid import core from paddle.fluid.log_helper import get_logger @@ -22,17 +27,9 @@ _logger = get_logger( __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s' ) -# lookup_table fp16 is slower than fp32, though fp16 is supported. -_extra_black_list = { - 'lookup_table', - 'lookup_table_v2', - 'scatter', - 'linear_interp_v2', - 'nearest_interp_v2', - 'bilinear_interp_v2', - 'bicubic_interp_v2', - 'trilinear_interp_v2', -} +black_list = FP16_BLACK_LIST +_extra_black_list = FP16_EXTRA_BLACK_LIST +white_list = FP16_WHITE_LIST def check_amp_dtype(dtype): @@ -131,45 +128,17 @@ def _get_unsupported_list(dtype): _only_supported_fp16_list = {'resnet_unit', 'fused_bn_add_activation'} -white_list = { - 'conv2d', - 'einsum', - 'matmul', - 'matmul_v2', - 'mul', -} - def _get_white_list(dtype): - white_list_for_dtype = copy.copy(white_list) + white_list_for_dtype = copy.copy(FP16_WHITE_LIST) if dtype == 'float16': white_list_for_dtype = white_list_for_dtype | _only_supported_fp16_list return white_list_for_dtype -# The set of ops that support fp16 calculation and are considered numerically- -# dangerous and whose effects may also be observed in downstream ops. -black_list = { - 'exp', - 'square', - 'log', - 'mean', - 'sum', - 'cos_sim', - 'softmax', - 'softmax_with_cross_entropy', - 'sigmoid_cross_entropy_with_logits', - 'c_softmax_with_cross_entropy', - 'cross_entropy', - 'cross_entropy2', - # default fp32 can avoid return inf when the sum value large than 65504 - 'reduce_sum', -} - - def _get_black_list(): - _black_list = copy.copy(black_list) - _black_list = _black_list | _extra_black_list + _black_list = copy.copy(FP16_BLACK_LIST) + _black_list = _black_list | FP16_EXTRA_BLACK_LIST return _black_list