#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager, wrap_decorator
from paddle.fluid import core
import contextlib
from paddle.fluid.framework import Variable, in_dygraph_mode, OpProtoHolder, Parameter, _dygraph_tracer, dygraph_only, set_flags, get_flags
import warnings
import copy
import functools
import paddle
import operator
import types

AMP_LEVEL = core.AmpLevel

__all__ = ['amp_guard', 'amp_decorate']

# The set of ops that support fp16 calculation and are considered numerically-
# safe and performance-critical. These ops are always converted to fp16.
WHITE_LIST = {
    'conv2d',
    'matmul',
    'matmul_v2',
    'mul',
    'fake_quantize_dequantize_abs_max',
    'fake_quantize_dequantize_moving_average_abs_max',
}

# The set of ops that support fp16 calculation and are considered numerically-
# dangerous and whose effects may also be observed in downstream ops.
BLACK_LIST = {
    'exp',
    'square',
    'log',
    'mean',
    'sum',
    'cos_sim',
    'softmax',
    'softmax_with_cross_entropy',
    'sigmoid_cross_entropy_with_logits',
    'c_softmax_with_cross_entropy',
    'cross_entropy',
    'cross_entropy2',
    # defaulting to fp32 avoids returning inf when the sum exceeds 65504 (the fp16 max)
    'reduce_sum',
}

AMP_RELATED_FLAGS = [
    'FLAGS_cudnn_exhaustive_search',
    'FLAGS_conv_workspace_size_limit',
    'FLAGS_cudnn_batchnorm_spatial_persistent',
]

AMP_RELATED_FLAGS_SETTING = {
    'FLAGS_cudnn_exhaustive_search': 1,
    'FLAGS_conv_workspace_size_limit': 1000,
    'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
}
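
# A minimal sketch of how these flags would be saved and restored around an
# amp region with the get_flags/set_flags helpers imported above (currently
# users must do this manually; see the TODO inside amp_guard):
#
#   original_flags = get_flags(AMP_RELATED_FLAGS)
#   set_flags(AMP_RELATED_FLAGS_SETTING)
#   ...  # run the amp-enabled code
#   set_flags(original_flags)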

PURE_FP16_WHITE_LIST = {' '}
PURE_FP16_BLACK_LIST = {'lookup_table', 'lookup_table_v2'}


#NOTE(zhiqiu): similar to paddle.fluid.contrib.mixed_precision.fp16_lists.AutoMixedPrecisionLists._update_list
# AutoMixedPrecisionLists is not reused here because its custom_black_varnames is not suitable for imperative mode.
def _update_list(custom_white_list, custom_black_list, level='O1'):
    """
    Update black and white list according to users' custom list.
    """
    if level == 'O1':
        _white_list = copy.copy(WHITE_LIST)
        _black_list = copy.copy(BLACK_LIST)
    else:
        _white_list = copy.copy(PURE_FP16_WHITE_LIST)
        _black_list = copy.copy(PURE_FP16_BLACK_LIST)
    if custom_white_list and custom_black_list:
        for op_name in custom_white_list:
            if op_name in custom_black_list:
                raise ValueError("Custom white list overlap "
                                 "custom black list")
    if custom_white_list:
        for op_name in custom_white_list:
            if op_name in _black_list:
                _black_list.remove(op_name)
            _white_list.add(op_name)
    if custom_black_list:
        for op_name in custom_black_list:
            if op_name in _white_list:
                _white_list.remove(op_name)
            _black_list.add(op_name)
    return _white_list, _black_list
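
# A minimal usage sketch of _update_list (op names taken from the default
# lists above): promoting 'softmax' via the custom white list removes it from
# the black list, so it will run in fp16 under amp_guard.
#
#   white, black = _update_list(
#       custom_white_list={'softmax'}, custom_black_list=None, level='O1')
#   assert 'softmax' in white and 'softmax' not in black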


def _in_amp_guard():
    """
    Judge whether current code block is in `amp_guard` context.
    """
    tracer = _dygraph_tracer()
    if tracer:
        return tracer._amp_level == core.AmpLevel.O1
    return False


def _in_pure_fp16_guard():
    """
    Return True if the current code block is inside an `amp_guard` pure fp16 (O2) context.
    """
    tracer = _dygraph_tracer()
    return tracer and tracer._amp_level == core.AmpLevel.O2
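
# Both guards above just read the tracer's _amp_level, so they return True only
# while the corresponding amp_guard context is active. A minimal sketch
# (assuming a supported GPU/XPU place, otherwise amp_guard disables itself):
#
#   with amp_guard(level='O2'):
#       assert _in_pure_fp16_guard()
#   assert not _in_pure_fp16_guard()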


@dygraph_only
def pure_fp16_initialize(models):
    for idx in range(len(models)):
        for layer in models[idx].sublayers(include_self=True):
            layer._casted_by_pure_fp16 = True
            if (layer._dtype == 'float16') or isinstance(layer, (
                    paddle.nn.BatchNorm, paddle.nn.LayerNorm)):
                continue
            layer._to_impl(dtype='float16', include_sublayers=False)
    return models
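
# A small illustration with hypothetical layers: after pure_fp16_initialize,
# ordinary layers hold fp16 parameters, while BatchNorm and LayerNorm are
# skipped by the isinstance check above and keep fp32 parameters.
#
#   net = paddle.nn.Sequential(paddle.nn.Linear(4, 4), paddle.nn.LayerNorm(4))
#   net = pure_fp16_initialize(models=[net])[0]
#   # net[0].weight.dtype -> float16, net[1].weight.dtype -> float32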


def check_models(models):
    for model in models:
        if not isinstance(model, paddle.nn.Layer):
            raise RuntimeError(
                "Current train mode is pure fp16, models should be paddle.nn.Layer, but receive {}.".
                format(type(model)))


def check_optimizers(optimizers):
    for optimizer in optimizers:
        if not isinstance(optimizer, (paddle.optimizer.Optimizer,
                                      paddle.fluid.optimizer.Optimizer)):
            raise RuntimeError(
                "Current train mode is pure fp16, optimizers should be paddle.optimizer.Optimizer or paddle.fluid.optimizer.Optimizer, but receive {}.".
                format(type(optimizer)))


@signature_safe_contextmanager
@dygraph_only
def amp_guard(enable=True,
              custom_white_list=None,
              custom_black_list=None,
              level='O1'):
    """
    :api_attr: imperative

    Create a context which enables auto-mixed-precision (AMP) for operators executed in dynamic graph mode.
    If enabled, the input data type (float32 or float16) of each operator is decided
    by the autocast algorithm for better performance.

    Commonly, it is used together with `GradScaler` to achieve Auto-Mixed-Precision in
    imperative mode. It is used together with `amp_decorate` to achieve Pure fp16 in imperative mode.

    Args:
        enable(bool, optional): Enable auto-mixed-precision or not. Default is True.
        custom_white_list(set|list|tuple, optional): The custom white_list. It's the set of ops that support
             fp16 calculation and are considered numerically-safe and performance-critical. These ops 
             will be converted to fp16.
        custom_black_list(set|list|tuple, optional): The custom black_list. The set of ops that support fp16
             calculation and are considered numerically-dangerous and whose effects may also be 
             observed in downstream ops. These ops will not be converted to fp16.
        level(str, optional): Auto mixed precision level. Accepted values are "O0", "O1" and "O2": O0 represents fp32 (amp disabled); O1 represents mixed precision, where the input data type of each operator is cast according to the white_list and black_list;
             O2 represents pure fp16, where the parameters and input data of all operators are cast to fp16, except for operators in the black_list, operators without fp16 kernels, and batch norm. Default is O1 (amp).

        
    Examples:

     .. code-block:: python

        import numpy as np
        import paddle

        data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
        with paddle.fluid.dygraph.guard():
            conv2d = paddle.fluid.dygraph.Conv2D(3, 2, 3)
            data = paddle.fluid.dygraph.to_variable(data)
            with paddle.fluid.dygraph.amp_guard():
                conv = conv2d(data)
                print(conv.dtype) # FP16
            with paddle.fluid.dygraph.amp_guard(enable=False):
                conv = conv2d(data)
                print(conv.dtype) # FP32

    """
    if level not in ['O0', 'O1', 'O2']:
        raise ValueError(
            "level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16 train mode."
        )

    tracer = _dygraph_tracer()
    if not tracer:
        raise ValueError(
            "current_tracer is None, maybe it is not in imperative mode.")

    if enable and not (tracer._expected_place.is_gpu_place() or
                       tracer._expected_place.is_xpu_place()):
        warnings.warn(
            'amp_guard can only be enabled on CUDAPlace and XPUPlace, current place is %s, so it has no effect.'
            % tracer._expected_place)
        enable = False

    if tracer._expected_place.is_gpu_place():
        prop = paddle.device.cuda.get_device_capability()
        if prop[0] < 7:
            warnings.warn(
                "AMP only supports NVIDIA GPUs with Compute Capability 7.0 or higher, current GPU is: %s, with Compute Capability: %d.%d."
                % (paddle.device.cuda.get_device_name(), prop[0], prop[1]))

    if level == 'O1':
        amp_level = AMP_LEVEL.O1
        _white_list = WHITE_LIST
        _black_list = BLACK_LIST
    elif level == 'O2':
        amp_level = AMP_LEVEL.O2
        _white_list = PURE_FP16_WHITE_LIST
        _black_list = PURE_FP16_BLACK_LIST
    elif level == 'O0':
        amp_level = AMP_LEVEL.O0
        _white_list = WHITE_LIST
        _black_list = BLACK_LIST

    if custom_white_list or custom_black_list:
        _white_list, _black_list = _update_list(custom_white_list,
                                                custom_black_list, level)

    if not enable:
        amp_level = AMP_LEVEL.O0

    if tracer:
        # enable auto_cast
        original_amp_level = tracer._amp_level
        tracer._amp_level = amp_level

        # set amp op list
        original_white_list, original_black_list = tracer._get_amp_op_list()
        tracer._set_amp_op_list(_white_list, _black_list)

        # TODO(zhiqiu) set amp related flags automatically in this guard
        # Currently, if FLAGS_cudnn_batchnorm_spatial_persistent is set True in amp_guard,
        # batch_norm can run in fast mode, but batch_norm_grad cannot if the backward pass is not executed inside amp_guard.
        # So, users need to set the related flags manually.

        # original_flags = get_flags(AMP_RELATED_FLAGS)
        # set_flags(AMP_RELATED_FLAGS_SETTING)

    # restore status
    try:
        yield
    finally:
        if tracer:
            tracer._amp_level = original_amp_level
            tracer._set_amp_op_list(original_white_list, original_black_list)
            # set_flags(original_flags)


class StateDictHook(object):
    """
    A state_dict hook that casts saved parameters to `save_dtype` without
    changing the dtype of the in-memory parameters.
    """

    def __init__(self, save_dtype):
        self._save_dtype = save_dtype

    def __call__(self, state_dict):
        for key in state_dict:
            param = state_dict[key]
            with paddle.fluid.dygraph.guard():
                param_applied = paddle.cast(param, self._save_dtype)
                param_applied.name = param.name
                state_dict[key] = param_applied
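
# Sketch of how this hook is wired up (amp_decorate below registers one on
# every sublayer when save_dtype is given): state_dict() then returns
# parameters cast to save_dtype, while the in-memory parameters keep their
# original dtype.
#
#   layer.register_state_dict_hook(StateDictHook('float32'))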


@dygraph_only
def amp_decorate(models,
                 optimizers=None,
                 level='O1',
                 master_weight=None,
                 save_dtype=None):
    """
    Decorate models and optimizers for auto-mixed-precision. When level is O1 (amp), the decorator does nothing.
    When level is O2 (pure fp16), the decorator casts all parameters of the models to FP16, except those of BatchNorm and LayerNorm layers.
    
    Commonly, it is used together with `amp_guard` to achieve Pure fp16 in imperative mode.

    Args:
        models(Layer|list of Layer, optional): The models defined by the user; must be either a single model or a list of models. Default is None.
        optimizers(Optimizer|list of Optimizer, optional): The optimizers defined by the user; must be either a single optimizer or a list of optimizers. Default is None.
        level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represents mixed precision, and the decorator does nothing;
             O2 represents pure fp16, and the decorator casts all parameters of the models to FP16, except those of BatchNorm and LayerNorm layers. Default is O1 (amp).
        master_weight(bool, optional): For level='O2', whether to use multi-precision during weight updating. If master_weight is None, the optimizer uses multi-precision in O2 level. Default is None.
        save_dtype(str, optional): The dtype of saved model parameters when using `paddle.save` or `paddle.jit.save`; it should be float16, float32, float64 or None.
             save_dtype does not change the model parameters' dtype, it only changes the state_dict dtype. When save_dtype is None, the saved dtype is the same as the model dtype. Default is None.

    Examples:

     .. code-block:: python   
        
        # required: gpu
        # Demo1: single model and optimizer:
        import paddle

        model = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
        optimizer = paddle.optimizer.SGD(parameters=model.parameters())

        model, optimizer = paddle.fluid.dygraph.amp_decorate(models=model, optimizers=optimizer, level='O2')

        data = paddle.rand([10, 3, 32, 32])

        with paddle.fluid.dygraph.amp_guard(enable=True, custom_white_list=None, custom_black_list=None, level='O2'):
            output = model(data)
            print(output.dtype) # FP16

        # required: gpu
        # Demo2: multi models and optimizers:
        model2 = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
        optimizer2 = paddle.optimizer.Adam(parameters=model2.parameters())

        models, optimizers = paddle.fluid.dygraph.amp_decorate(models=[model, model2], optimizers=[optimizer, optimizer2], level='O2')

        data = paddle.rand([10, 3, 32, 32])

        with paddle.fluid.dygraph.amp_guard(enable=True, custom_white_list=None, custom_black_list=None, level='O2'):
            output = models[0](data)
            output2 = models[1](data)
            print(output.dtype) # FP16
            print(output2.dtype) # FP16
        
        # required: gpu
        # Demo3: optimizers is None:
        model3 = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
        optimizer3 = paddle.optimizer.Adam(parameters=model3.parameters())

        model = paddle.fluid.dygraph.amp_decorate(models=model3, level='O2')

        data = paddle.rand([10, 3, 32, 32])

        with paddle.fluid.dygraph.amp_guard(enable=True, custom_white_list=None, custom_black_list=None, level='O2'):
            output = model(data)
            print(output.dtype) # FP16
    """
    if level not in ['O1', 'O2']:
        raise ValueError(
            "level should be O1 or O2. O1 represents AMP train mode, O2 represents pure fp16 train mode."
        )

    if level == 'O1':
        if optimizers is None:
            return models
        else:
            return models, optimizers

    models_is_list = False
    if isinstance(models, paddle.nn.Layer):
        models_is_list = False
        models = [models]
        check_models(models)
    elif isinstance(models, list):
        check_models(models)
        models_is_list = True
    else:
        raise TypeError(
            "models must be either a single model or a list of models.")

    models = pure_fp16_initialize(models=models)

    if optimizers is not None:
        # check optimizers
        optimizers_is_list = False
        if isinstance(optimizers, (paddle.optimizer.Optimizer,
                                   paddle.fluid.optimizer.Optimizer)):
            optimizers_is_list = False
            optimizers = [optimizers]
            check_optimizers(optimizers)
        elif isinstance(optimizers, list):
            check_optimizers(optimizers)
            optimizers_is_list = True
        else:
            raise TypeError(
                "optimizers must be either a single optimizer or a list of optimizers."
            )
        # support master_weight
        for idx_opt in range(len(optimizers)):
            if hasattr(optimizers[idx_opt], '_multi_precision'):
                if master_weight is False:
                    optimizers[idx_opt]._multi_precision = False
                else:
                    optimizers[idx_opt]._multi_precision = True
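        # Note: in the loop above, master_weight is effectively three-valued:
        # None and True both enable multi-precision (fp32 master weights);
        # only an explicit False disables it.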

    if save_dtype is not None:
        if save_dtype not in ['float16', 'float32', 'float64']:
            raise ValueError(
                "save_dtype can only be float16, float32 or float64, but your input save_dtype is %s."
                % save_dtype)
        for idx in range(len(models)):
            for layer in models[idx].sublayers(include_self=True):
                layer.register_state_dict_hook(StateDictHook(save_dtype))

    if models_is_list:
        if optimizers is not None:
            if optimizers_is_list:
                return models, optimizers
            else:
                return models, optimizers[0]
        else:
            return models
    else:
        if optimizers is not None:
            if optimizers_is_list:
                return models[0], optimizers
            else:
                return models[0], optimizers[0]
        else:
            return models[0]