# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# The set of ops that support fp16 calculation and are considered numerically-
# safe and performance-critical. These ops are always converted to fp16.
FP16_WHITE_LIST = {
    'conv2d',
    'einsum',
    'matmul',
    'matmul_v2',
    'max_pool2d_with_index',
    'mul',
    'fake_quantize_dequantize_abs_max',
    'fake_quantize_dequantize_moving_average_abs_max',
}

# The set of ops that support fp16 calculation but are considered numerically
# dangerous; their effects may also be observed in downstream ops.
FP16_BLACK_LIST = {
    'tan',
    'acos',
    'asin',
    'sinh',
    'cosh',
    'atanh',
    'tanh_shrink',
    'cos_sim',
    'erfinv',
    'exp',
    'expm1',
    'log',
    'log10',
    'log2',
    'reciprocal',
    'rsqrt',
    'pow',
    'square',
    'reduce_sum',
    'mean',
    'reduce_mean',
    'reduce_prod',
    'cumprod',
    'cumsum',
    'dist',
    'pnorm',
    'frobenius_norm',
    'renorm',
    'group_norm',
    'layer_norm',
    'softmax',
    'softmin',
    'softplus',
    'log_softmax',
    'softmax_with_cross_entropy',
    'sigmoid_cross_entropy_with_logits',
    'c_softmax_with_cross_entropy',
    'cross_entropy',
    'cross_entropy2',
    'nll_loss',
    'huber_loss',
    'triplet_margin_loss',
    'log_loss',
    'hsigmoid_loss',
    'margin_cross_entropy',
}

# The FP16/BF16 performance of these ops' grad kernels is worse than that of
# FP32, so they use FP32 by default.
EXTRA_BLACK_LIST = {
    'linear_interp_v2',
    'nearest_interp_v2',
    'bilinear_interp_v2',
    'bicubic_interp_v2',
    'trilinear_interp_v2',
    'lookup_table',
    'lookup_table_v2',
    'scatter',
}

BF16_WHITE_LIST = {'conv2d', 'einsum', 'matmul_v2'}
BF16_BLACK_LIST = set()


# At the OD level, only ops in the white list use FP16/BF16; all other ops
# use FP32.
def white_list():
    white_list = {
        "float16": {
            "OD": FP16_WHITE_LIST,
            "O1": FP16_WHITE_LIST,
            "O2": FP16_WHITE_LIST,
        },
        "bfloat16": {
            "OD": BF16_WHITE_LIST,
            "O1": BF16_WHITE_LIST,
            "O2": BF16_WHITE_LIST,
        },
    }
    return white_list


def black_list():
    black_list = {
        "float16": {
            "OD": set(),
            "O1": FP16_BLACK_LIST | EXTRA_BLACK_LIST,
            "O2": EXTRA_BLACK_LIST,
        },
        "bfloat16": {
            "OD": set(),
            "O1": BF16_BLACK_LIST | EXTRA_BLACK_LIST,
            "O2": EXTRA_BLACK_LIST,
        },
    }
    return black_list
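

# The sketch below is a usage example added for illustration; it is not part
# of the original module. It shows how a caller could consult white_list()
# and black_list() to decide the compute dtype of an op at a given AMP level.
# The helper name `op_compute_dtype` and the fallback policy for unlisted ops
# are assumptions, not the framework's actual dispatch logic.
def op_compute_dtype(op_type, dtype="float16", level="O1"):
    """Return the dtype an op would run in under the given AMP setting."""
    if op_type in white_list()[dtype][level]:
        return dtype  # white-listed: always run in low precision
    if op_type in black_list()[dtype][level]:
        return "float32"  # black-listed: keep full precision
    # Unlisted ops: at the OD level everything outside the white list falls
    # back to FP32; at O1/O2 this sketch simply assumes low precision.
    return "float32" if level == "OD" else dtype


if __name__ == "__main__":
    # 'conv2d' is in FP16_WHITE_LIST, so it runs in float16 at every level.
    print(op_compute_dtype("conv2d", "float16", "OD"))   # float16
    # 'softmax' is in FP16_BLACK_LIST, so it stays in float32 at O1.
    print(op_compute_dtype("softmax", "float16", "O1"))  # float32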