# amp_nn.py
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.fluid.data_feeder import check_variable_and_dtype, check_type
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import Variable
from paddle.fluid import core

__all__ = ['check_finite_and_unscale', 'update_loss_scaling']


def check_finite_and_unscale(x, scale, name=None, float_status=None):
    """
    Check if input X contains all finite data, if yes, scale it by input Scale.

    $$Out = X / scale$$

    If any tensor in X contains Inf or Nan, the operator raises an indicator:
    FoundInfinite will be 1 (True), and Out will not be scaled. In this case,
    the data of Out should not be used, and its data may not be deterministic.
    Otherwise, FoundInfinite will be 0 (False).

    This function only validates its arguments and appends a
    ``check_finite_and_unscale`` operator through ``LayerHelper``. The
    operator's ``Out`` is bound to the same variables as its ``X`` input, so
    the result is delivered through ``x`` itself.

    Args:
        x(list|tuple): The input tensors of check_finite_and_unscale operator.
            Every element must have dtype float16, float32 or float64.
        scale: The scale of check_finite_and_unscale operator.
        name(str, optional): Forwarded to ``LayerHelper`` together with the
            other local variables. Default: None.
        float_status(Tensor): (Only used on NPU) The float status to check
            overflow. When Paddle is compiled with NPU support it is
            validated as float16/float32 and passed to the operator as the
            ``FloatStatus`` input; otherwise it is ignored.

    Returns:
        tuple: ``(x, found_inf)`` where ``x`` is the list/tuple that was
        passed in and ``found_inf`` is a newly created bool variable bound
        to the operator's ``FoundInfinite`` output.
    """
    check_type(x, 'x', (tuple, list), 'check_finite_and_unscale')
    for e in x:
        check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'],
                                 'check_finite_and_unscale')

    # LayerHelper receives every local (x, scale, name, ...) as keyword
    # arguments, exactly as the function has always done.
    helper = LayerHelper("check_finite_and_unscale", **locals())
    found_inf = helper.create_variable_for_type_inference(dtype='bool')

    inputs = {'X': x, 'Scale': scale}
    if core.is_compiled_with_npu():
        # The extra input is only wired in for NPU builds.
        check_variable_and_dtype(float_status, "float_status",
                                 ['float16', 'float32'],
                                 'check_finite_and_unscale')
        inputs['FloatStatus'] = float_status

    # 'Out' reuses the input variables: the unscaled values overwrite x.
    outputs = {'Out': x, 'FoundInfinite': found_inf}
    helper.append_op(type='check_finite_and_unscale',
                     inputs=inputs,
                     outputs=outputs)

    return x, found_inf


def update_loss_scaling(x,
                        found_inf,
                        prev_loss_scaling,
                        num_good_steps,
                        num_bad_steps,
                        incr_every_n_steps,
                        decr_every_n_nan_or_inf,
                        incr_ratio,
                        decr_ratio,
                        stop_update=False,
                        name=None):
    """
    Update loss scaling according to overall gradients. If all gradients are
    finite after incr_every_n_steps, loss scaling will increase by incr_ratio.
    Otherwise, loss scaling will decrease by decr_ratio after
    decr_every_n_nan_or_inf steps and each step some gradients are infinite.

    This function only validates its arguments and appends an
    ``update_loss_scaling`` operator through ``LayerHelper``. The operator's
    outputs are bound to the same variables as its inputs (``x``,
    ``prev_loss_scaling``, ``num_good_steps``, ``num_bad_steps``), so those
    variables carry the updated state.

    Args:
        x(list|tuple): The input tensors of update_loss_scaling operator.
            Every element must have dtype float16, float32 or float64.
        found_inf (Variable): A boolean variable indicates whether
                              there is any infinite gradient.
        prev_loss_scaling (Variable): Previous loss scaling. Must be float32
                                      or float64; float32 when an element of
                                      ``x`` is float16, otherwise the same
                                      dtype as the elements of ``x``.
        num_good_steps (Variable): A variable accumulates good steps in which
                                   all gradients are finite.
        num_bad_steps (Variable): A variable accumulates bad steps in which
                                  some gradients are infinite.
        incr_every_n_steps (int): Increase loss scaling every n consecutive
                                  steps with finite gradients.
        decr_every_n_nan_or_inf (int): Decrease loss scaling every n
                                       accumulated steps with nan or inf
                                       gradients.
        incr_ratio(float): The multiplier to use when increasing the loss
                           scaling.
        decr_ratio(float): The less-than-one-multiplier to use when decreasing
                           loss scaling.
        stop_update(bool|Variable, optional): If it is a ``Variable`` it is
            passed to the operator as the ``StopUpdate`` input; any other
            value is passed as the ``stop_update`` attribute. Presumably it
            tells the operator to skip the update -- confirm against the
            operator definition. Default: False.
        name(str, optional): Forwarded to ``LayerHelper`` together with the
            other local variables. Default: None.

    Returns:
        list|tuple: ``x``, the same object that was passed in.

    Raises:
        AssertionError: If the dtype of ``prev_loss_scaling`` does not match
            the dtype rule described above.
    """

    check_variable_and_dtype(prev_loss_scaling, "prev_loss_scaling",
                             ['float32', 'float64'], "update_loss_scaling")
    check_type(x, 'x', (tuple, list), 'update_loss_scaling')
    for e in x:
        check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'],
                                 'update_loss_scaling')
        # NOTE: these checks stay as asserts so callers keep seeing
        # AssertionError; be aware they are skipped when Python runs with -O.
        if e.dtype == core.VarDesc.VarType.FP16:
            assert prev_loss_scaling.dtype == core.VarDesc.VarType.FP32, \
                "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
        else:
            assert prev_loss_scaling.dtype == e.dtype, "The dtype of prev_loss_scaling should be equal to the dtype of x."

    # LayerHelper receives every local as keyword arguments, exactly as the
    # function has always done.
    helper = LayerHelper("update_loss_scaling", **locals())

    inputs = {
        'X': x,
        'FoundInfinite': found_inf,
        'PrevLossScaling': prev_loss_scaling,
        'InGoodSteps': num_good_steps,
        'InBadSteps': num_bad_steps
    }

    # Every output reuses the matching input variable, so the operator
    # overwrites the state it reads.
    outputs = {
        'Out': x,
        'LossScaling': prev_loss_scaling,
        'OutGoodSteps': num_good_steps,
        'OutBadSteps': num_bad_steps
    }

    attrs = {
        'incr_every_n_steps': incr_every_n_steps,
        'decr_every_n_nan_or_inf': decr_every_n_nan_or_inf,
        'incr_ratio': incr_ratio,
        'decr_ratio': decr_ratio,
    }

    # A Variable is passed as an operator input; any other value is passed
    # as an operator attribute.
    if isinstance(stop_update, Variable):
        inputs['StopUpdate'] = stop_update
    else:
        attrs['stop_update'] = stop_update

    helper.append_op(type='update_loss_scaling',
                     inputs=inputs,
                     outputs=outputs,
                     attrs=attrs)

    return x