#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.fluid.data_feeder import check_variable_and_dtype, check_type
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import Variable
from paddle.fluid import core

__all__ = ['check_finite_and_unscale', 'update_loss_scaling']


def check_finite_and_unscale(x, scale, name=None, float_status=None):
    """
    Check whether input X contains only finite data; if so, scale it by input Scale.

    $$Out = X / scale$$

    If any tensor in X contains Inf or NaN, FoundInfinite will be 1 (True) and
    Out will not be scaled. In this case, the data of Out should not be used,
    as it may not be deterministic. Otherwise, FoundInfinite will be 0 (False).

    Args:
        x(list|tuple): The input tensors of check_finite_and_unscale operator.
        scale(Tensor): The scale of check_finite_and_unscale operator; each
            tensor in x is divided by it.
        float_status(Tensor): (Only used on NPU) The float status to check overflow.
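
    Examples:
        A minimal static-graph sketch of how this helper is typically used
        (tensor names, shapes and the import path are illustrative and may
        differ between Paddle versions):

            import paddle
            import paddle.static as static
            # assumed module path; adjust to the Paddle version in use
            from paddle.fluid.contrib.mixed_precision.amp_nn import (
                check_finite_and_unscale)

            paddle.enable_static()
            main_prog, startup_prog = static.Program(), static.Program()
            with static.program_guard(main_prog, startup_prog):
                # stand-in for gradients produced by a scaled backward pass
                grad = static.data(name='grad', shape=[4], dtype='float32')
                # the current loss scaling factor
                scale = static.data(name='scale', shape=[1], dtype='float32')
                # unscale the gradients and get an overflow indicator
                grads, found_inf = check_finite_and_unscale([grad], scale)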
    """
    check_type(x, 'x', (tuple, list), 'check_finite_and_unscale')
    for e in x:
        check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'],
                                 'check_finite_and_unscale')

    helper = LayerHelper("check_finite_and_unscale", **locals())
    found_inf = helper.create_variable_for_type_inference(dtype='bool')

    inputs = {'X': x, 'Scale': scale}
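    # on NPU, the op additionally consumes a float status tensor to detect overflow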
    if core.is_compiled_with_npu():
        check_variable_and_dtype(float_status, "float_status",
                                 ['float16', 'float32'],
                                 'check_finite_and_unscale')
        inputs['FloatStatus'] = float_status
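    # Out aliases X: the unscaled values are written back into the input tensors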
    outputs = {'Out': x, 'FoundInfinite': found_inf}
    helper.append_op(type='check_finite_and_unscale',
                     inputs=inputs,
                     outputs=outputs)

    return x, found_inf


def update_loss_scaling(x,
                        found_inf,
                        prev_loss_scaling,
                        num_good_steps,
                        num_bad_steps,
                        incr_every_n_steps,
                        decr_every_n_nan_or_inf,
                        incr_ratio,
                        decr_ratio,
                        stop_update=False,
                        name=None):
    """
    Update loss scaling according to the overall gradients. If all gradients
    are finite for incr_every_n_steps consecutive steps, the loss scaling will
    increase by incr_ratio. Otherwise, the loss scaling will decrease by
    decr_ratio after decr_every_n_nan_or_inf accumulated steps in which some
    gradients are infinite.

    Args:
        x(list|tuple): The input tensors of update_loss_scaling operator.
        found_inf (Variable): A boolean variable that indicates whether
                              there is any infinite gradient.
        prev_loss_scaling (Variable): Previous loss scaling.
        num_good_steps (Variable): A variable that accumulates the consecutive
                                   steps in which all gradients are finite.
        num_bad_steps (Variable): A variable that accumulates the steps in
                                  which some gradients are infinite.
        incr_every_n_steps (int): Increase the loss scaling every n consecutive
                                  steps with finite gradients.
        decr_every_n_nan_or_inf (int): Decrease the loss scaling every n
                                       accumulated steps with nan or inf
                                       gradients.
        incr_ratio(float): The multiplier to use when increasing the loss
                           scaling.
        decr_ratio(float): The less-than-one multiplier to use when decreasing
                           the loss scaling.
        stop_update(bool|Variable, optional): Whether to stop updating the loss
                           scaling; it can be a Python bool or a boolean
                           Variable decided at runtime. Default is False.
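
    Examples:
        A minimal static-graph sketch of how the scaling state is usually
        created and updated (values, variable names and the import path are
        illustrative and may differ between Paddle versions):

            import paddle
            import paddle.static as static
            # assumed module path; adjust to the Paddle version in use
            from paddle.fluid.contrib.mixed_precision.amp_nn import (
                check_finite_and_unscale, update_loss_scaling)

            paddle.enable_static()
            main_prog, startup_prog = static.Program(), static.Program()
            with static.program_guard(main_prog, startup_prog):
                # stand-in for a gradient produced by a scaled backward pass
                grad = static.data(name='grad', shape=[4], dtype='float32')
                # persistable state: loss scaling and good/bad step counters
                loss_scaling = static.create_global_var(
                    shape=[1], value=2.**15, dtype='float32',
                    persistable=True, name='loss_scaling')
                num_good_steps = static.create_global_var(
                    shape=[1], value=0, dtype='int32',
                    persistable=True, name='num_good_steps')
                num_bad_steps = static.create_global_var(
                    shape=[1], value=0, dtype='int32',
                    persistable=True, name='num_bad_steps')
                # unscale the gradients and detect overflow
                grads, found_inf = check_finite_and_unscale([grad], loss_scaling)
                # grow or shrink the loss scaling in place based on found_inf
                update_loss_scaling(grads, found_inf, loss_scaling,
                                    num_good_steps, num_bad_steps,
                                    incr_every_n_steps=1000,
                                    decr_every_n_nan_or_inf=2,
                                    incr_ratio=2.0, decr_ratio=0.5)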
    """

    check_variable_and_dtype(prev_loss_scaling, "prev_loss_scaling",
                             ['float32', 'float64'], "update_loss_scaling")
    check_type(x, 'x', (tuple, list), 'update_loss_scaling')
    for e in x:
        check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'],
                                 'update_loss_scaling')
        if e.dtype == core.VarDesc.VarType.FP16:
            assert prev_loss_scaling.dtype == core.VarDesc.VarType.FP32, \
                "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
        else:
            assert prev_loss_scaling.dtype == e.dtype, "The dtype of prev_loss_scaling should be equal to the dtype of x."

    helper = LayerHelper("update_loss_scaling", **locals())

    inputs = {
        'X': x,
        'FoundInfinite': found_inf,
        'PrevLossScaling': prev_loss_scaling,
        'InGoodSteps': num_good_steps,
        'InBadSteps': num_bad_steps
    }

    outputs = {
        'Out': x,
        'LossScaling': prev_loss_scaling,
        'OutGoodSteps': num_good_steps,
        'OutBadSteps': num_bad_steps
    }

    attrs = {
        'incr_every_n_steps': incr_every_n_steps,
        'decr_every_n_nan_or_inf': decr_every_n_nan_or_inf,
        'incr_ratio': incr_ratio,
        'decr_ratio': decr_ratio,
    }

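    # stop_update may be a compile-time bool (passed as an op attribute) or a
    # boolean Variable evaluated at runtime (passed as an op input)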
    if isinstance(stop_update, Variable):
        inputs['StopUpdate'] = stop_update
    else:
        attrs['stop_update'] = stop_update

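    # all outputs alias their corresponding inputs, so x, the loss scaling and
    # the step counters are updated in place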
    helper.append_op(type='update_loss_scaling',
                     inputs=inputs,
                     outputs=outputs,
                     attrs=attrs)

    return x