#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle import _C_ops
from paddle.fluid import core
from paddle.fluid.data_feeder import check_type, check_variable_and_dtype
from paddle.fluid.framework import Variable, in_dygraph_mode
from paddle.fluid.layer_helper import LayerHelper

__all__ = ['check_finite_and_unscale', 'update_loss_scaling']


def check_finite_and_unscale(x, scale, name=None, float_status=None):
    """
    Check whether input X contains only finite data; if so, scale (divide) it by input Scale.

    $$Out = X / scale$$

    If any tensor in X contains Inf or NaN, FoundInfinite will be 1 (True)
    and Out will not be scaled. In this case, the data of Out should not be
    used, as it may not be deterministic.
    Otherwise, FoundInfinite will be 0 (False).

    Args:
        x(list|tuple): The input tensors of the check_finite_and_unscale
            operator.
        scale(Tensor): The scale of the check_finite_and_unscale operator.
        name(str, optional): Name for the operation. Default is None.
        float_status(Tensor, optional): Only used on NPU. The float status
            to check for overflow. Default is None.
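
    Examples:
        A minimal dynamic-graph sketch. The values below are illustrative
        assumptions; in AMP training, the scale and the gradients come from
        the loss scaler and the backward pass.

        .. code-block:: python

            import paddle

            # Illustrative values only; assumes this function has been
            # imported from its module.
            scale = paddle.to_tensor([1024.0], dtype='float32')
            grads = [paddle.to_tensor([2048.0, 4096.0], dtype='float32')]

            out, found_inf = check_finite_and_unscale(grads, scale)
            # With no Inf/NaN present, out holds grads divided by scale
            # (here [2.0, 4.0]) and found_inf is False.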
    """
    check_type(x, 'x', (tuple, list), 'check_finite_and_unscale')
    for e in x:
        check_variable_and_dtype(
            e,
            "x",
            ['float16', 'float32', 'float64'],
            'check_finite_and_unscale',
        )

    helper = LayerHelper("check_finite_and_unscale", **locals())

    found_inf = helper.create_variable_for_type_inference(dtype='bool')

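    # In dynamic graph mode, call the in-place kernel directly: x is
    # unscaled in place and found_inf is written with the overflow flag.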
    if in_dygraph_mode():
        _C_ops.check_finite_and_unscale_(x, scale, found_inf)
        return x, found_inf

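    # Static graph mode: append the operator to the current program.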
    inputs = {'X': x, 'Scale': scale}
    if core.is_compiled_with_npu():
        check_variable_and_dtype(
            float_status,
            "float_status",
            ['float16', 'float32'],
            'check_finite_and_unscale',
        )
        inputs['FloatStatus'] = float_status
    outputs = {'Out': x, 'FoundInfinite': found_inf}
    helper.append_op(
        type='check_finite_and_unscale', inputs=inputs, outputs=outputs
    )

    return x, found_inf


def update_loss_scaling(
    x,
    found_inf,
    prev_loss_scaling,
    num_good_steps,
    num_bad_steps,
    incr_every_n_steps,
    decr_every_n_nan_or_inf,
    incr_ratio,
    decr_ratio,
    stop_update=False,
    name=None,
):
    """
    Update loss scaling according to the state of all gradients. If all
    gradients are finite for incr_every_n_steps consecutive steps, the loss
    scaling will increase by incr_ratio. Otherwise, once infinite gradients
    have accumulated over decr_every_n_nan_or_inf steps, the loss scaling
    will decrease by decr_ratio.

    Args:
        x(list|tuple): The input tensors of the update_loss_scaling operator.
        found_inf(Variable): A boolean variable indicating whether there is
            any infinite gradient.
        prev_loss_scaling(Variable): Previous loss scaling.
        num_good_steps(Variable): A variable that accumulates the consecutive
            steps in which all gradients are finite.
        num_bad_steps(Variable): A variable that accumulates the steps in
            which some gradients are infinite.
        incr_every_n_steps(int): Increase the loss scaling every n
            consecutive steps with finite gradients.
        decr_every_n_nan_or_inf(int): Decrease the loss scaling every n
            accumulated steps with NaN or Inf gradients.
        incr_ratio(float): The multiplier to use when increasing the loss
            scaling.
        decr_ratio(float): The less-than-one multiplier to use when
            decreasing the loss scaling.
        stop_update(bool|Variable, optional): Whether to stop updating the
            loss scaling. Default is False.
        name(str, optional): Name for the operation. Default is None.
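
    Examples:
        A minimal dynamic-graph sketch. The tensor values and counter dtypes
        below are illustrative assumptions, not fixed requirements.

        .. code-block:: python

            import paddle

            # Hypothetical AMP state; a real loss scaler maintains these
            # tensors across training steps.
            grads = [paddle.to_tensor([2.0, 4.0], dtype='float32')]
            found_inf = paddle.to_tensor([False])
            loss_scaling = paddle.to_tensor([1024.0], dtype='float32')
            num_good_steps = paddle.to_tensor([0], dtype='int32')
            num_bad_steps = paddle.to_tensor([0], dtype='int32')

            update_loss_scaling(
                grads, found_inf, loss_scaling, num_good_steps,
                num_bad_steps, incr_every_n_steps=1000,
                decr_every_n_nan_or_inf=2, incr_ratio=2.0, decr_ratio=0.5)
            # loss_scaling and the step counters are updated in place.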
    """

    check_variable_and_dtype(
        prev_loss_scaling,
        "prev_loss_scaling",
        ['float32', 'float64'],
        "update_loss_scaling",
    )
    check_type(x, 'x', (tuple, list), 'update_loss_scaling')
    for e in x:
        check_variable_and_dtype(
            e, "x", ['float16', 'float32', 'float64'], 'update_loss_scaling'
        )
        if e.dtype == core.VarDesc.VarType.FP16:
            assert (
                prev_loss_scaling.dtype == core.VarDesc.VarType.FP32
            ), "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
        else:
            assert (
                prev_loss_scaling.dtype == e.dtype
            ), "The dtype of prev_loss_scaling should be equal to the dtype of x."

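    # Dynamic graph mode: run the in-place kernel; the loss scaling and
    # step counters are updated in place.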
    if in_dygraph_mode():
        _C_ops.update_loss_scaling_(
            x,
            found_inf,
            prev_loss_scaling,
            num_good_steps,
            num_bad_steps,
            incr_every_n_steps,
            decr_every_n_nan_or_inf,
            incr_ratio,
            decr_ratio,
            stop_update,
        )
        return x

    helper = LayerHelper("update_loss_scaling", **locals())

    inputs = {
        'X': x,
        'FoundInfinite': found_inf,
        'PrevLossScaling': prev_loss_scaling,
        'InGoodSteps': num_good_steps,
        'InBadSteps': num_bad_steps,
    }

    outputs = {
        'Out': x,
        'LossScaling': prev_loss_scaling,
        'OutGoodSteps': num_good_steps,
        'OutBadSteps': num_bad_steps,
    }

    attrs = {
        'incr_every_n_steps': incr_every_n_steps,
        'decr_every_n_nan_or_inf': decr_every_n_nan_or_inf,
        'incr_ratio': incr_ratio,
        'decr_ratio': decr_ratio,
    }

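    # stop_update may be a Variable (fed as an op input) or a Python bool
    # (set as an op attribute).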
    if isinstance(stop_update, Variable):
        inputs['StopUpdate'] = stop_update
    else:
        attrs['stop_update'] = stop_update

    helper.append_op(
        type='update_loss_scaling', inputs=inputs, outputs=outputs, attrs=attrs
    )

    return x