fp16_lists.py 6.0 KB
Newer Older
J
Jie Fang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

J
Jie Fang 已提交
15
import copy
16 17

from paddle.fluid import core
J
Jie Fang 已提交
18

S
sneaxiy 已提交
19
# lookup_table fp16 is slower than fp32, though fp16 is supported.
20
_extra_unsupported_fp16_list = {
21 22 23 24
    'lookup_table',
    'lookup_table_v2',
    'scatter',
    'scatter_grad',
25
}
S
sneaxiy 已提交
26

J
Jie Fang 已提交
27

28
class AutoMixedPrecisionLists:
    """
    AutoMixedPrecisionLists is a class for black/white list. It can update
    pre-defined black list and white list according to users' custom black
    and white lists. The lists are used for an algorithm which determines
    op's execution mode (fp32 or fp16).

    Args:
        custom_white_list (set): Users' custom white list.
        custom_black_list (set): Users' custom black list.
        custom_black_varnames (set): Users' custom black variables' names.

    Raises:
        ValueError: If an op name appears in both custom_white_list and
            custom_black_list.
    """

    def __init__(
        self,
        custom_white_list=None,
        custom_black_list=None,
        custom_black_varnames=None,
    ):
        self._custom_white_list = custom_white_list
        self._custom_black_list = custom_black_list
        # Shallow copies of the module-level defaults: _update_list mutates
        # these per instance and must not leak changes into other instances.
        self.white_list = copy.copy(white_list)
        self.black_list = copy.copy(black_list)
        self.gray_list = copy.copy(gray_list)
        self.unsupported_list = copy.copy(unsupported_fp16_list)
        # Stays None when the caller passed no custom black variable names.
        self.black_varnames = copy.copy(custom_black_varnames)
        self._update_list()

    def _update_list(self):
        """
        Update black and white list according to users' custom list.

        Each op in the custom white list is taken out of the black and gray
        lists and added to the white list; if it is one of the ops in
        ``_extra_unsupported_fp16_list`` it is also taken off the unsupported
        list. Each op in the custom black list is taken out of the white and
        gray lists and added to both the black list and the unsupported list.

        ``set.discard`` is used rather than ``set.remove`` so that an op name
        repeated in a custom list (e.g. when a list is passed instead of a
        set) does not raise ``KeyError`` on its second occurrence.

        Raises:
            ValueError: If the custom white and black lists share an op.
        """
        custom_white = self._custom_white_list
        custom_black = self._custom_black_list

        if (
            custom_white
            and custom_black
            and any(op_name in custom_black for op_name in custom_white)
        ):
            raise ValueError("Custom white list overlap custom black list")

        if custom_white:
            for op_name in custom_white:
                self.black_list.discard(op_name)
                self.gray_list.discard(op_name)
                self.white_list.add(op_name)
                # Only ops from the module's extra list are released here;
                # other entries of the unsupported list are left untouched.
                if op_name in _extra_unsupported_fp16_list:
                    self.unsupported_list.discard(op_name)

        if custom_black:
            for op_name in custom_black:
                self.white_list.discard(op_name)
                self.gray_list.discard(op_name)
                self.black_list.add(op_name)
                # Also mark it unsupported, presumably so it is never cast to
                # fp16 by the rewriting pass.
                self.unsupported_list.add(op_name)


# The three sets listed below are changed dynamically. They don't contain all
# paddle ops currently.

# The set of ops that support fp16 calculation and are considered numerically-
# safe and performance-critical. These ops are always converted to fp16
# (unless a custom black list moves them out of the white list).
white_list = {
    'conv2d',
    'matmul',
    'matmul_v2',
    'mul',
}

# The set of ops that support fp16 calculation but are considered numerically
# dangerous, and whose effects may also be observed in downstream ops.
black_list = {
    'exp',
    'square',
    'log',
    'mean',
    'sum',
    'cos_sim',
    'softmax',
    'softmax_with_cross_entropy',
    'sigmoid_cross_entropy_with_logits',
    'c_softmax_with_cross_entropy',
    'cross_entropy',
    'cross_entropy2',
    # lookup_table / lookup_table_v2: fp16 is slower than fp32, though fp16
    # is supported.
    'lookup_table',
    'lookup_table_v2',
    'linear_interp_v2',
    'nearest_interp_v2',
    'bilinear_interp_v2',
    'bicubic_interp_v2',
    'trilinear_interp_v2',
    # Keeping reduce_sum in fp32 by default avoids returning inf when the sum
    # exceeds 65504, the largest finite fp16 value.
    'reduce_sum',
}

# This set contains two types of ops, all of which support fp16 calculation.
# The first type is considered numerically safe, but may be made unsafe by an
# upstream black-list op. The second type has no numerically significant
# effects, like stack and flatten2.
gray_list = {
    'elementwise_add',
    'elementwise_sub',
    'elementwise_mul',
    'elementwise_div',
    'elementwise_max',
    'elementwise_min',
    'elementwise_pow',
    'elementwise_mod',
    'elementwise_floordiv',
    'batch_norm',
    'layer_norm',
    'tanh',
    'sigmoid',
    'top_k',
    'pool2d',
    'pool3d',
    'dropout',
    'relu',
    'relu6',
    'leaky_relu',
    'soft_relu',
    'flatten2',
    'stack',
    'unstack',
    'uniform_random',
    'uniform_random_batch_size_like',
    'gaussian_random',
    'gaussian_random_batch_size_like',
    'slice',
    'rank',
    'scale',
    'transpose2',
    'reshape2',
    'gather',
    'fill_constant',
    'get_tensor_from_selected_rows',
    'sign',
    'cast',
    'fused_bn_add_activation',
    'c_identity',
    'c_concat',
    'c_allreduce_sum',
    'concat',
    'split',
    'fused_feedforward',
    'fused_attention',
    'fused_multi_transformer',
}

# The set of ops that don't support fp16 calculation on the device Paddle was
# compiled for: XPU if available, else a custom 'npu' device, otherwise GPU.
# The third element returned by core.op_supported_infos is taken as that set.
if core.is_compiled_with_xpu():
    _fp16_query_device = 'XPU'
elif core.is_compiled_with_custom_device('npu'):
    _fp16_query_device = 'NPU'
else:
    _fp16_query_device = 'GPU'
_, _, _sys_unsupported_fp16_list = core.op_supported_infos(
    _fp16_query_device, core.VarDesc.VarType.FP16
)

# Device-unsupported ops plus the module's extra list. set() is applied to the
# device result because the ``|`` operator only accepts set operands.
unsupported_fp16_list = _extra_unsupported_fp16_list | set(
    _sys_unsupported_fp16_list
)

# Alias kept so code referring to ``CustomOpLists`` gets the same class.
CustomOpLists = AutoMixedPrecisionLists