fp16_utils.py 14.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from ... import core
from ... import layers
from ... import framework


def append_cast_op(i, o, prog):
    """
    Append one ``cast`` op to the global block of a Program.

    The op takes `i` as its ``X`` input and `o` as its ``Out`` output, and
    records `i.dtype` / `o.dtype` as its ``in_dtype`` / ``out_dtype`` attrs.

    Args:
        i (Variable): The variable to be cast.
        o (Variable): The variable receiving the cast result.
        prog (Program): The Program whose global block gets the cast op.
    """
    target_block = prog.global_block()
    cast_attrs = {"in_dtype": i.dtype, "out_dtype": o.dtype}
    target_block.append_op(
        type="cast", inputs={"X": i}, outputs={"Out": o}, attrs=cast_attrs)


J
Jie Fang 已提交
39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
def _rename_arg(op, old_name, new_name):
    """
    Rename `old_name` to `new_name` among both the input and the output
    args of an op.

    Args:
        op (Operator): The operator whose description is edited.
        old_name (str): The arg name to be replaced.
        new_name (str): The arg name to use instead.
    """
    desc = op.desc
    # When the description comes wrapped in a tuple, its first element is
    # the one that gets renamed.
    desc = desc[0] if isinstance(desc, tuple) else desc
    desc._rename_input(old_name, new_name)
    desc._rename_output(old_name, new_name)


def _dtype_to_str(dtype):
    """
    Map a variable dtype to the short tag used in cast variable names.

    Args:
        dtype (VarType): Variable type.

    Returns:
        str: 'fp16' when `dtype` equals VarType.FP16; 'fp32' for every
        other value.
    """
    is_half = dtype == core.VarDesc.VarType.FP16
    return 'fp16' if is_half else 'fp32'


def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
    """
    Insert cast ops in front of `op` and rename its input args accordingly.

    Every input variable of `op` whose dtype equals `src_dtype` is replaced
    by a variable named ``<input name>.cast_<fp16|fp32>`` with dtype
    `dest_dtype`. The cast variable and the cast op producing it are created
    only if `block` does not already hold a variable of that name with dtype
    `dest_dtype`; otherwise the existing one is reused.

    When `src_dtype` is FP32, the FP32 output variables of `op` are
    additionally switched to FP16 in place. For a ``batch_norm`` op in that
    mode only the ``X`` input and the ``Y`` output are processed.

    Args:
        block (Block): The block in which the operator is.
        op (Operator): The operator to insert cast op.
        idx (int): The index of current operator; cast ops are inserted at
            this index.
        src_dtype (VarType): The input variable dtype of cast op.
        dest_dtype (VarType): The output variable dtype of cast op.

    Returns:
        num_cast_ops (int): The number of cast ops that have been inserted.
    """
    fp32 = core.VarDesc.VarType.FP32
    fp16 = core.VarDesc.VarType.FP16
    valid_types = [
        core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.SELECTED_ROWS,
        core.VarDesc.VarType.LOD_TENSOR_ARRAY
    ]
    src_is_fp32 = src_dtype == fp32
    # For batch_norm with FP32 sources, only input 'X' and output 'Y' are
    # touched; all its other inputs/outputs are left as they are.
    only_main_tensor = src_is_fp32 and op.type == 'batch_norm'

    num_cast_ops = 0
    for in_name in op.input_names:
        if only_main_tensor and in_name != 'X':
            continue
        for in_var_name in op.input(in_name):
            in_var = block.var(in_var_name)
            if in_var.type not in valid_types:
                continue
            if in_var.dtype == src_dtype:
                cast_name = in_var.name + '.cast_' + _dtype_to_str(dest_dtype)
                out_var = block.vars.get(cast_name)
                # Reuse an existing cast result of the right dtype instead
                # of inserting a duplicate cast op.
                if out_var is None or out_var.dtype != dest_dtype:
                    out_var = block.create_var(
                        name=cast_name,
                        dtype=dest_dtype,
                        persistable=False,
                        stop_gradient=False)

                    block._insert_op(
                        idx,
                        type="cast",
                        inputs={"X": in_var},
                        outputs={"Out": out_var},
                        attrs={
                            "in_dtype": in_var.dtype,
                            "out_dtype": out_var.dtype
                        })
                    num_cast_ops += 1
                _rename_arg(op, in_var.name, out_var.name)
            elif op.has_attr('in_dtype'):
                # No cast for this input; the op's own 'in_dtype' attr, if
                # it has one, is set to dest_dtype.
                op._set_attr('in_dtype', dest_dtype)

    if src_is_fp32:
        for out_name in op.output_names:
            if only_main_tensor and out_name != 'Y':
                continue
            for out_var_name in op.output(out_name):
                out_var = block.var(out_var_name)
                if out_var.type not in valid_types:
                    continue
                if out_var.dtype == fp32:
                    # Flip the FP32 output to FP16 directly on its desc.
                    out_var.desc.set_dtype(fp16)
                    if op.has_attr('out_dtype'):
                        op._set_attr('out_dtype', fp16)
    return num_cast_ops


136 137 138 139 140 141 142 143 144 145
def find_true_prev_op(ops, cur_op, var_name):
    """
    Find the true prev op that outputs var_name variable.

    Only the ops located before `cur_op` in `ops` are examined; the scan
    stops at the first op that compares equal to `cur_op`.

    Args:
        ops (list): A list of ops.
        cur_op (Operator): Current operator which has var_name variable.
        var_name (string): Variable name.

    Returns:
        Operator or None: The single op before `cur_op` that outputs
        `var_name`, or None if there is no such op.

    Raises:
        ValueError: If more than one op before `cur_op` outputs `var_name`.
    """
    producers = []
    for op in ops:
        if op == cur_op:
            break
        # Count each op once, even when it lists var_name under several of
        # its output slots; otherwise a single producer would be mistaken
        # for multiple previous ops.
        if any(var_name in op.output(out_name)
               for out_name in op.output_names):
            producers.append(op)

    if not producers:
        return None
    if len(producers) != 1:
        raise ValueError("There must be only one previous op "
                         "that outputs {0} variable".format(var_name))
    return producers[0]
J
Jie Fang 已提交
160 161


162 163 164 165 166 167 168 169 170 171 172 173
def _is_in_black_varnames(op, amp_lists):
    """
    Tell whether any input or output arg name of `op` is listed in
    `amp_lists.black_varnames`.

    Args:
        op (Operator): The operator to inspect.
        amp_lists: Object exposing a `black_varnames` container.

    Returns:
        bool: True if at least one arg name of `op` is in
        `amp_lists.black_varnames`, otherwise False.
    """
    # Inputs are checked first; outputs are only looked at when no input
    # name matched.
    return (any(arg in amp_lists.black_varnames
                for arg in op.input_arg_names) or
            any(arg in amp_lists.black_varnames
                for arg in op.output_arg_names))


J
Jie Fang 已提交
174
def rewrite_program(main_prog, amp_lists):
    """
    Traverse all ops in current block and insert cast op according to
    which set current op belongs to.

    0. When `amp_lists.black_varnames` is not None and an op reads or
       writes one of those variable names, add it to black set.
    1. When an op belongs to the black list, add it to black set
    2. When an op belongs to the white list, add it to white set
    3. When an op belongs to the gray list. If one
       of its inputs is the output of black set op or black list op,
       add it to black set. If all of its previous ops are not black
       op and one of its inputs is the output of white set op or
       white list op, add it to white set.
    4. When an op isn't in the lists, add it to black op set.
    5. Add necessary cast ops to make sure that black set op will be
       computed in fp32 mode, while white set op will be computed in
       fp16 mode.

    Args:
        main_prog (Program): The main program for training.
        amp_lists: Object exposing `black_list`, `white_list`, `gray_list`
            (collections of op types) and `black_varnames` (collection of
            variable names, or None).
    """
    block = main_prog.global_block()
    ops = block.ops
    white_op_set = set()
    black_op_set = set()

    # Pass 1: classify every op as black (fp32), white (fp16) or neither.
    for op in ops:
        if amp_lists.black_varnames is not None and _is_in_black_varnames(
                op, amp_lists):
            black_op_set.add(op)
            continue

        if op.type in amp_lists.black_list:
            black_op_set.add(op)
        elif op.type in amp_lists.white_list:
            white_op_set.add(op)
        elif op.type in amp_lists.gray_list:
            is_black_op = False
            is_white_op = False
            for in_name in op.input_names:
                # if this op has inputs
                if in_name:
                    for in_var_name in op.input(in_name):
                        in_var = block.var(in_var_name)
                        # this in_var isn't the output of other op
                        if in_var.op is None:
                            continue
                        elif in_var.op is op:
                            # The variable reports this very op as its
                            # producer, so look for an earlier op that
                            # also outputs it.
                            prev_op = find_true_prev_op(ops, op, in_var_name)
                            if prev_op is None:
                                continue
                        else:
                            prev_op = in_var.op
                        # if it's one of inputs
                        if prev_op in black_op_set or \
                                prev_op.type in amp_lists.black_list:
                            is_black_op = True
                        elif prev_op in white_op_set or \
                                prev_op.type in amp_lists.white_list:
                            is_white_op = True
            # Black takes precedence over white; a gray op with neither
            # kind of predecessor stays unclassified.
            if is_black_op:
                black_op_set.add(op)
            elif is_white_op:
                white_op_set.add(op)
        else:
            # For numerical safe, we apply fp32 computation on ops that
            # are not determined which list they should stay.
            black_op_set.add(op)

    # Pass 2: insert the casts. `ops` grows while cast ops are inserted in
    # front of the current op, so the index skips over the new ops.
    idx = 0
    while idx < len(ops):
        op = ops[idx]
        num_cast_ops = 0
        if op in black_op_set:
            num_cast_ops = _insert_cast_op(block, op, idx,
                                           core.VarDesc.VarType.FP16,
                                           core.VarDesc.VarType.FP32)
        elif op in white_op_set:
            num_cast_ops = _insert_cast_op(block, op, idx,
                                           core.VarDesc.VarType.FP32,
                                           core.VarDesc.VarType.FP16)

        idx += num_cast_ops + 1


261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299
def update_role_var_grad(main_prog, params_grads):
    """
    Update op_role_var attr for some ops to make sure the gradients
    transferred across gpus is FP16.
    1. Check whether the op that outputs gradient is cast or not.
    2. If op is cast and gradient is FP32, remove the op_role_var
       and find the prev op which outputs FP16 gradient
    3. Update the op_role_var of the prev op.

    Args:
        main_prog (Program): The main program for training.
        params_grads (list): A list of params and grads.

    Raises:
        ValueError: If a cast op producing an FP32 gradient is not in
            BACKWARD role or has no op_role_var attr, or if no op that
            outputs the FP16 gradient can be found before the cast op.
    """
    block = main_prog.global_block()
    BACKWARD = core.op_proto_and_checker_maker.OpRole.Backward
    OPTIMIZE = core.op_proto_and_checker_maker.OpRole.Optimize
    op_role_var_attr_name = \
        core.op_proto_and_checker_maker.kOpRoleVarAttrName()
    for p, g in params_grads:
        op = g.op
        if g.dtype == core.VarDesc.VarType.FP32 and op.type == 'cast':
            role = op.attr('op_role')
            if role & int(BACKWARD) and op.has_attr('op_role_var'):
                op.desc.remove_attr("op_role_var")
            else:
                raise ValueError("The cast op {0} must be in BACKWARD role "
                                 "and have op_role_var attr.".format(op))

            # The cast op's first input is the FP16 gradient.
            fp16_grad_name = op.input(op.input_names[0])[0]
            op_for_fp16_grad = find_true_prev_op(block.ops, op, fp16_grad_name)
            if op_for_fp16_grad is None:
                # Fail with a clear message instead of an AttributeError
                # on None further down.
                raise ValueError("Cannot find the op that outputs {0} "
                                 "before the cast op {1}.".format(
                                     fp16_grad_name, op))
            attr_val = [p.name, fp16_grad_name]
            if op_for_fp16_grad.has_attr(op_role_var_attr_name):
                attr_val.extend(op_for_fp16_grad.attr(op_role_var_attr_name))
            op_for_fp16_grad._set_attr(op_role_var_attr_name, attr_val)

            # maximize the allreduce overlap
            op._set_attr('op_role', OPTIMIZE)


J
Jie Fang 已提交
300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371
def update_loss_scaling(is_overall_finite, prev_loss_scaling, num_good_steps,
                        num_bad_steps, incr_every_n_steps,
                        decr_every_n_nan_or_inf, incr_ratio, decr_ratio):
    """
    Update loss scaling according to overall gradients. If all gradients are
    finite after incr_every_n_steps, loss scaling will increase by incr_ratio.
    Otherwise, loss scaling will decrease by decr_ratio after
    decr_every_n_nan_or_inf steps and each step some gradients are infinite.

    The function returns nothing; it only emits layers ops. The results are
    presumably written back into `prev_loss_scaling`, `num_good_steps` and
    `num_bad_steps` through `layers.assign` / `layers.increment` (confirm
    against the layers API).

    Args:
        is_overall_finite (Variable): A boolean variable indicates whether
                                     all gradients are finite.
        prev_loss_scaling (Variable): Previous loss scaling.
        num_good_steps (Variable): A variable accumulates good steps in which
                                   all gradients are finite.
        num_bad_steps (Variable): A variable accumulates bad steps in which
                                  some gradients are infinite.
        incr_every_n_steps (Variable): A variable represents increasing loss
                                       scaling every n consecutive steps with
                                       finite gradients.
        decr_every_n_nan_or_inf (Variable): A variable represents decreasing
                                            loss scaling every n accumulated
                                            steps with nan or inf gradients.
        incr_ratio(float): The multiplier to use when increasing the loss
                           scaling.
        decr_ratio(float): The less-than-one-multiplier to use when decreasing
                           loss scaling.
    """
    # Constant 0 used to reset both step counters.
    zero_steps = layers.fill_constant(shape=[1], dtype='int32', value=0)
    with layers.Switch() as switch:
        # Branch taken when all gradients are finite.
        with switch.case(is_overall_finite):
            # Compares incr_every_n_steps against num_good_steps + 1, i.e.
            # the count including the current step (assuming
            # less_than(x, y) means x < y).
            should_incr_loss_scaling = layers.less_than(incr_every_n_steps,
                                                        num_good_steps + 1)
            with layers.Switch() as switch1:
                with switch1.case(should_incr_loss_scaling):
                    new_loss_scaling = prev_loss_scaling * incr_ratio
                    loss_scaling_is_finite = layers.isfinite(new_loss_scaling)
                    # Only adopt the increased scaling if it is still
                    # finite; otherwise keep the previous value.
                    with layers.Switch() as switch2:
                        with switch2.case(loss_scaling_is_finite):
                            layers.assign(new_loss_scaling, prev_loss_scaling)
                        with switch2.default():
                            pass
                    layers.assign(zero_steps, num_good_steps)
                    layers.assign(zero_steps, num_bad_steps)

                with switch1.default():
                    # Not enough good steps yet: count this one and reset
                    # the bad-step counter.
                    layers.increment(num_good_steps)
                    layers.assign(zero_steps, num_bad_steps)

        # Branch taken when is_overall_finite does not hold.
        with switch.default():
            # Same comparison as above, for the bad-step counter.
            should_decr_loss_scaling = layers.less_than(decr_every_n_nan_or_inf,
                                                        num_bad_steps + 1)
            with layers.Switch() as switch3:
                with switch3.case(should_decr_loss_scaling):
                    new_loss_scaling = prev_loss_scaling * decr_ratio
                    # Lower bound for the loss scaling: 1.0.
                    static_loss_scaling = \
                        layers.fill_constant(shape=[1],
                                             dtype='float32',
                                             value=1.0)
                    less_than_one = layers.less_than(new_loss_scaling,
                                                     static_loss_scaling)
                    # Use 1.0 instead of the decreased scaling when the
                    # latter would fall below 1.0.
                    with layers.Switch() as switch4:
                        with switch4.case(less_than_one):
                            layers.assign(static_loss_scaling,
                                          prev_loss_scaling)
                        with switch4.default():
                            layers.assign(new_loss_scaling, prev_loss_scaling)
                    layers.assign(zero_steps, num_good_steps)
                    layers.assign(zero_steps, num_bad_steps)
                with switch3.default():
                    # Not enough bad steps yet: reset the good-step counter
                    # and count this bad step.
                    layers.assign(zero_steps, num_good_steps)
                    layers.increment(num_bad_steps)