#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging

from paddle.fluid import core
from paddle.fluid.log_helper import get_logger

_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)

# lookup_table fp16 is slower than fp32, though fp16 is supported.
_extra_black_list = {
    'lookup_table',
    'lookup_table_v2',
    'scatter',
    'linear_interp_v2',
    'nearest_interp_v2',
    'bilinear_interp_v2',
    'bicubic_interp_v2',
    'trilinear_interp_v2',
}


def check_amp_dtype(dtype):
    """
    Check amp_dtype: float16 or bfloat16
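
    A minimal illustrative sketch of the expected behavior (not a formal spec):

        check_amp_dtype("FLOAT16")   # returns "float16" (input is lower-cased)
        check_amp_dtype("float32")   # raises ValueError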
    """
    if isinstance(dtype, str):
        dtype = dtype.lower()
    if dtype not in ['float16', 'bfloat16']:
        raise ValueError(
            "If enable AMP, dtype should be 'float16' or 'bfloat16'."
        )
    return dtype


def get_low_precision_vartype(dtype):
    if isinstance(dtype, core.VarDesc.VarType):
        return dtype
    elif isinstance(dtype, str):
        dtype = dtype.lower()
        if dtype == "float16":
            var_type = core.VarDesc.VarType.FP16
        elif dtype == "bfloat16":
            var_type = core.VarDesc.VarType.BF16
        else:
            raise ValueError(
                "If enable AMP, dtype should be 'float16' or 'bfloat16'."
            )
        return var_type
    else:
        raise TypeError(
            "The type of dtype is expected to be string or core.VarDesc.VarType, but recieved {}.".format(
                type(dtype)
            )
        )


def get_low_precision_dtypestr(dtype):
    if isinstance(dtype, str):
        return check_amp_dtype(dtype)
    elif isinstance(dtype, core.VarDesc.VarType):
        if dtype == core.VarDesc.VarType.FP16:
            return "float16"
        elif dtype == core.VarDesc.VarType.BF16:
            return "bfloat16"
        else:
            raise ValueError(
                "If enable AMP, dtype should be core.VarDesc.VarType.FP16 or core.VarDesc.VarType.BF16."
            )
    else:
        raise TypeError(
            "The type of dtype is expected to be string or core.VarDesc.VarType, but recieved {}.".format(
                type(dtype)
            )
        )


def _get_sys_unsupported_list(dtype):
    var_type = get_low_precision_vartype(dtype)

    # The set of ops that don't support low-precision (fp16/bf16) calculation
    device = None
    if core.is_compiled_with_xpu():
        device = 'XPU'
    else:
        device = 'GPU'
    all_ops, _, sys_unsupported_list = core.op_supported_infos(device, var_type)

    # sys_unsupported_list may include the following ops, which are treated as
    # fp16-supported here, so remove them from the unsupported set.
    supported_fp16_list = {
        "conditional_block",
        "conditional_block_infer",
        "select_input",
        "while",
        "cast",
        "tensor_array_to_tensor",
        "lod_array_length",
        "write_to_array",
    }
    sys_unsupported_list -= supported_fp16_list

    return device, sys_unsupported_list, all_ops


def _get_unsupported_list(dtype):
    # The set of ops that don't support low-precision (fp16/bf16) calculation
    _, _sys_unsupported_list, _sys_all_list = _get_sys_unsupported_list(dtype)
    return _sys_unsupported_list, _sys_all_list


# The three sets listed below are changed dynamically. They don't contain all
# Paddle ops currently.

# The set of ops that support fp16 calculation and are considered numerically-
# safe and performance-critical. These ops are always converted to fp16.

_only_supported_fp16_list = {'resnet_unit', 'fused_bn_add_activation'}

white_list = {
    'conv2d',
    'einsum',
    'matmul',
    'matmul_v2',
    'mul',
}


def _get_white_list(dtype):
    white_list_for_dtype = copy.copy(white_list)
    if dtype == 'float16':
        white_list_for_dtype = white_list_for_dtype | _only_supported_fp16_list
    return white_list_for_dtype
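
# Illustrative sketch of _get_white_list (based only on the sets defined above):
# for 'float16' the fp16-only fused ops are added, while 'bfloat16' gets an
# unchanged copy of white_list.
#   _get_white_list('float16')   # also contains 'resnet_unit', 'fused_bn_add_activation'
#   _get_white_list('bfloat16')  # == white_list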


# The set of ops that support fp16 calculation and are considered numerically-
# dangerous and whose effects may also be observed in downstream ops.
black_list = {
    'exp',
    'square',
    'log',
    'mean',
    'sum',
    'cos_sim',
    'softmax',
    'softmax_with_cross_entropy',
    'sigmoid_cross_entropy_with_logits',
    'c_softmax_with_cross_entropy',
    'cross_entropy',
    'cross_entropy2',
    # defaulting to fp32 avoids returning inf when the sum is larger than 65504
    'reduce_sum',
}


def _get_black_list():
    _black_list = copy.copy(black_list)
    _black_list = _black_list | _extra_black_list
    return _black_list


class AutoMixedPrecisionLists:
    """
    AutoMixedPrecisionLists is a class for black/white lists. It can update
    pre-defined black and white lists according to users' custom black and
    white lists. The lists are used by an algorithm which determines an op's
    execution mode (fp32, fp16 or bf16).

    Args:
        custom_white_list (set): Users' custom white list.
        custom_black_list (set): Users' custom black list.
        custom_black_varnames (set): Users' custom black variables' names.
        dtype (str): the low precision dtype, which can be set to 'float16' or 'bfloat16'.
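
    Examples:
        A minimal usage sketch (illustrative only; the public import path of this
        class may differ across Paddle releases, e.g. it is also exported as
        ``CustomOpLists``):

        .. code-block:: python

            amp_lists = AutoMixedPrecisionLists(
                custom_white_list={'elementwise_add'},
                custom_black_list={'conv2d'},
                dtype='float16',
            )
            # A custom white op is moved out of the pre-defined gray/black lists,
            # and a custom black op is moved out of the pre-defined white list.
            assert 'elementwise_add' in amp_lists.white_list
            assert 'conv2d' in amp_lists.black_list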
    """

    def __init__(
        self,
        custom_white_list=None,
        custom_black_list=None,
        custom_black_varnames=None,
        dtype="float16",
    ):
        self.amp_dtype = check_amp_dtype(dtype)
        self._custom_white_list = custom_white_list
        self._custom_black_list = custom_black_list
        self.white_list = copy.copy(_get_white_list(self.amp_dtype))
        self.black_list = copy.copy(_get_black_list())
        self.gray_list = copy.copy(gray_list)
        unsupported_list, sys_all_list = _get_unsupported_list(self.amp_dtype)
        self.unsupported_list = copy.copy(unsupported_list)
        self.all_list = copy.copy(sys_all_list)
        self.black_varnames = copy.copy(custom_black_varnames)
        self._update_list()

    def _update_list(self):
        """
        Update the black and white lists according to users' custom lists.
        """
        _logger.debug(f"---- custom_white_list {self._custom_white_list} ---- ")
        _logger.debug(f"---- custom_black_list {self._custom_black_list} ---- ")
        _logger.debug(f"---- custom_black_varnames {self.black_varnames} ---- ")
        if self._custom_white_list and self._custom_black_list:
            for op_name in self._custom_white_list:
                if op_name in self._custom_black_list:
                    raise ValueError(
                        f"The given custom_white_list overlaps custom_black_list with < {op_name} >!"
                    )
        if self._custom_white_list:
            for op_name in self._custom_white_list:
                if op_name in self.black_list:
                    self.black_list.remove(op_name)
                elif op_name in self.gray_list:
                    self.gray_list.remove(op_name)
                self.white_list.add(op_name)
        if self._custom_black_list:
            for op_name in self._custom_black_list:
                if op_name in self.white_list:
                    self.white_list.remove(op_name)
                elif op_name in self.gray_list:
                    self.gray_list.remove(op_name)
                self.black_list.add(op_name)
                self.unsupported_list.add(op_name)
        device, sys_unsupported_list, _ = _get_sys_unsupported_list(
            self.amp_dtype
        )
        actual_unsupported_list = []
        for op_name in sys_unsupported_list:
            if op_name in self.white_list:
                actual_unsupported_list.append(op_name)
        if len(actual_unsupported_list) > 0:
            _logger.warning(
                f"On current {device}, {self.amp_dtype} is not supported for operators < {actual_unsupported_list} > in white_list!"
            )


# This set contains two types of ops. All of them support fp16 calculation. One
# type is considered numerically-safe, but may be made unsafe by an upstream
# blacklisted op. The other type does not have numerically-significant
# effects, like stack, flatten2.
gray_list = {
    'elementwise_add',
    'elementwise_sub',
    'elementwise_mul',
    'elementwise_div',
    'elementwise_max',
    'elementwise_min',
    'elementwise_pow',
    'elementwise_mod',
    'elementwise_floordiv',
    'batch_norm',
    'layer_norm',
    'tanh',
    'sigmoid',
    'top_k',
    'pool2d',
    'pool3d',
    'dropout',
    'relu',
    'relu6',
    'leaky_relu',
    'soft_relu',
    'flatten2',
    'stack',
    'unstack',
    'uniform_random',
    'uniform_random_batch_size_like',
    'gaussian_random',
    'gaussian_random_batch_size_like',
    'slice',
    'rank',
    'scale',
    'transpose2',
    'reshape2',
    'gather',
    'fill_constant',
    'get_tensor_from_selected_rows',
    'sign',
    'cast',
    'fused_bn_add_activation',
    'c_identity',
    'c_concat',
    'c_allreduce_sum',
    'concat',
    'split',
    'fused_feedforward',
    'fused_attention',
    'fused_multi_transformer',
}

CustomOpLists = AutoMixedPrecisionLists