Unverified commit 64ecdc03, authored by niuliling123, committed by GitHub

Change Static AMP List (#54135)

Parent: 2ddd0473
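The commit replaces the hardcoded static-AMP op lists in this file with the shared definitions from paddle.amp.amp_lists. A minimal sketch of how the shared constants are consumed after this change (the membership checks below are inferred from the inline lists removed in the diff, not verified against the shared module):

from paddle.amp.amp_lists import (
    FP16_BLACK_LIST,
    FP16_EXTRA_BLACK_LIST,
    FP16_WHITE_LIST,
)

# White-listed ops are cast to fp16; black-listed ops are kept in fp32.
print('matmul' in FP16_WHITE_LIST)              # True, per the removed inline white_list
print('softmax' in FP16_BLACK_LIST)             # True, per the removed inline black_list
print('lookup_table' in FP16_EXTRA_BLACK_LIST)  # True, per the removed _extra_black_list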
@@ -15,6 +15,11 @@
 import copy
 import logging
+from paddle.amp.amp_lists import (
+    FP16_BLACK_LIST,
+    FP16_EXTRA_BLACK_LIST,
+    FP16_WHITE_LIST,
+)
 from paddle.fluid import core
 from paddle.fluid.log_helper import get_logger
@@ -22,17 +27,9 @@ _logger = get_logger(
     __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
 )
-# lookup_table fp16 is slower than fp32, though fp16 is supported.
-_extra_black_list = {
-    'lookup_table',
-    'lookup_table_v2',
-    'scatter',
-    'linear_interp_v2',
-    'nearest_interp_v2',
-    'bilinear_interp_v2',
-    'bicubic_interp_v2',
-    'trilinear_interp_v2',
-}
+black_list = FP16_BLACK_LIST
+_extra_black_list = FP16_EXTRA_BLACK_LIST
+white_list = FP16_WHITE_LIST
 def check_amp_dtype(dtype):
@@ -131,45 +128,17 @@ def _get_unsupported_list(dtype):
 _only_supported_fp16_list = {'resnet_unit', 'fused_bn_add_activation'}
-white_list = {
-    'conv2d',
-    'einsum',
-    'matmul',
-    'matmul_v2',
-    'mul',
-}
 def _get_white_list(dtype):
-    white_list_for_dtype = copy.copy(white_list)
+    white_list_for_dtype = copy.copy(FP16_WHITE_LIST)
     if dtype == 'float16':
         white_list_for_dtype = white_list_for_dtype | _only_supported_fp16_list
     return white_list_for_dtype
-# The set of ops that support fp16 calculation and are considered numerically-
-# dangerous and whose effects may also be observed in downstream ops.
-black_list = {
-    'exp',
-    'square',
-    'log',
-    'mean',
-    'sum',
-    'cos_sim',
-    'softmax',
-    'softmax_with_cross_entropy',
-    'sigmoid_cross_entropy_with_logits',
-    'c_softmax_with_cross_entropy',
-    'cross_entropy',
-    'cross_entropy2',
-    # default fp32 can avoid return inf when the sum value large than 65504
-    'reduce_sum',
-}
 def _get_black_list():
-    _black_list = copy.copy(black_list)
-    _black_list = _black_list | _extra_black_list
+    _black_list = copy.copy(FP16_BLACK_LIST)
+    _black_list = _black_list | FP16_EXTRA_BLACK_LIST
     return _black_list
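After the edit, both helpers derive their results from the shared constants rather than from module-local sets. A rough usage sketch (the import path below is an assumption for illustration; the helpers themselves are the ones defined in the diff):

from paddle.static.amp.fp16_lists import _get_black_list, _get_white_list  # module path assumed

# For float16, the white list is FP16_WHITE_LIST plus the two ops that only
# support fp16 ('resnet_unit', 'fused_bn_add_activation').
fp16_white = _get_white_list('float16')
assert 'resnet_unit' in fp16_white

# The black list is the union of FP16_BLACK_LIST and FP16_EXTRA_BLACK_LIST,
# so both 'softmax' (base) and 'lookup_table' (extra) should be present.
full_black = _get_black_list()
assert 'softmax' in full_black and 'lookup_table' in full_black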