#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import logging
import struct

import numpy as np

from paddle.fluid import core, framework, global_scope
from paddle.fluid.log_helper import get_logger
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager

from ..fp16_utils import (
    _rename_arg,
    _rename_op_input,
    find_true_post_op,
    find_true_prev_op,
)
from .amp_lists import AutoMixedPrecisionListsBF16

# Module-level logger shared by all passes in this file.
_logger = get_logger(
    __name__,
    logging.INFO,
    fmt='%(asctime)s-%(levelname)s: %(message)s',
)

# Variable types the BF16 passes are allowed to inspect and retype; any
# variable whose type is not listed here is skipped.
_valid_types = [
    core.VarDesc.VarType.LOD_TENSOR,
    core.VarDesc.VarType.SELECTED_ROWS,
    core.VarDesc.VarType.LOD_TENSOR_ARRAY,
]

# Name-scope prefix installed by `bf16_guard`; `_need_keep_fp32` searches
# for it in an op's `op_namescope` attribute.
_bf16_guard_pattern = "__use_bf16__"


def convert_float_to_uint16(in_list):
    """
    Encode float values as BF16 bit patterns stored in ``uint16``.

    Each value is converted to IEEE-754 float32 and the upper 16 bits of its
    bit pattern are kept. This is truncation, not round-to-nearest. The
    result has the same shape as ``np.asarray(in_list)``.

    Finite float64 inputs beyond the float32 range follow NumPy's cast rules
    and are encoded as +/-inf.

    Args:
        in_list (array_like): Values to encode; any shape, including 0-d and
            empty inputs.

    Returns:
        numpy.ndarray: ``uint16`` array with the same shape as the input.
    """
    # Cast to float32, then reinterpret the same bytes as uint32: the high
    # half-word of a float32 is exactly its bf16 encoding. This runs in
    # native NumPy code instead of one struct.pack/unpack call per element.
    bits = np.asarray(in_list, dtype=np.float32).view(np.uint32)
    # np.asarray keeps a 0-d input as a 0-d array (a ufunc on a 0-d array
    # would otherwise return a NumPy scalar).
    return np.asarray(bits >> 16, dtype=np.uint16)


def _dtype_to_str(dtype):
    """
    Map a variable dtype to the short suffix used when naming cast variables.

    Args:
        dtype (VarType): Variable type.

    Returns:
        str: 'bf16' for ``core.VarDesc.VarType.BF16``; 'fp32' for every
        other dtype.
    """
    return 'bf16' if dtype == core.VarDesc.VarType.BF16 else 'fp32'


def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
    """
    Insert cast ops in front of ``op`` and rename its inputs to the casts.

    Every input of ``op`` whose dtype is ``src_dtype`` (and whose type is in
    ``_valid_types``) is cast to ``dest_dtype`` by a ``cast`` op inserted at
    ``idx``. A cast variable named ``<var>.cast_<bf16|fp32>`` is reused if it
    already exists with the right dtype. For FP32 -> BF16 conversion, FP32
    outputs of ``op`` are retyped to BF16 in place.

    For batch_norm / fused_bn_add_activation / layer_norm, only the ``X`` and
    ``Z`` inputs are cast when ``src_dtype`` is FP32. Only the ``Y`` output
    is retyped.

    Args:
        block (Block): The block in which the operator is.
        op (Operator): The operator to insert cast ops for.
        idx (int): The index of the current operator in ``block``.
        src_dtype (VarType): The input variable dtype of the cast op.
        dest_dtype (VarType): The output variable dtype of the cast op.

    Returns:
        int: The number of cast ops that have been inserted.
    """
    fp32 = core.VarDesc.VarType.FP32
    bf16 = core.VarDesc.VarType.BF16
    norm_op_types = ['batch_norm', 'fused_bn_add_activation', 'layer_norm']
    src_is_fp32 = src_dtype == fp32
    inserted = 0

    for slot in op.input_names:
        if src_is_fp32 and op.type in norm_op_types and slot not in {'X', 'Z'}:
            continue
        for var_name in op.input(slot):
            source = block.var(var_name)
            if source.type not in _valid_types or source.dtype == dest_dtype:
                continue
            if source.dtype != src_dtype:
                # Neither source nor destination dtype: only retarget the
                # op's own declared input dtype, if it has one.
                if op.has_attr('in_dtype'):
                    op._set_attr('in_dtype', dest_dtype)
                continue

            cast_name = source.name + '.cast_' + _dtype_to_str(dest_dtype)
            casted = block.vars.get(cast_name)
            if casted is None or casted.dtype != dest_dtype:
                casted = block.create_var(
                    name=cast_name,
                    dtype=dest_dtype,
                    persistable=False,
                    stop_gradient=source.stop_gradient,
                )
                block._insert_op(
                    idx,
                    type="cast",
                    inputs={"X": source},
                    outputs={"Out": casted},
                    attrs={
                        "in_dtype": source.dtype,
                        "out_dtype": casted.dtype,
                    },
                )
                inserted += 1
            _rename_arg(op, source.name, casted.name)

    if not (src_is_fp32 and dest_dtype == bf16):
        return inserted

    for slot in op.output_names:
        if op.type in norm_op_types and slot != 'Y':
            continue
        for var_name in op.output(slot):
            produced = block.var(var_name)
            if produced.type not in _valid_types:
                continue
            if produced.dtype == fp32:
                produced.desc.set_dtype(bf16)
                if op.has_attr('out_dtype'):
                    op._set_attr('out_dtype', bf16)
    return inserted


def _insert_cast_post_op(
    block, op, idx, src_dtype, dest_dtype, target_name, op_var_rename_map
):
    """
    Cast the variable ``target_name`` to ``dest_dtype`` for later consumers.

    A ``cast`` op writing ``<target>.cast_<bf16|fp32>`` is inserted at
    ``idx``. The renaming is recorded in ``op_var_rename_map`` so consumers
    can be redirected to the cast variable. Nothing is inserted if the
    target's type is not in ``_valid_types``, if it already has
    ``dest_dtype``, or if a cast variable with ``dest_dtype`` already exists.

    Args:
        block (Block): Block in which the cast op is inserted.
        op (Operator): Producer of ``target_name``; not read by this function.
        idx (int): Position in ``block`` at which the cast op is inserted.
        src_dtype (VarType): Dtype ``target_name`` is asserted to have.
        dest_dtype (VarType): Dtype of the cast result.
        target_name (str): Name of the variable to cast.
        op_var_rename_map (list): Per-block mappings, indexed by
            ``block.idx``, from original variable name to cast variable name.

    Returns:
        int: Number of cast ops inserted (0 or 1).
    """
    target = block.var(target_name)
    if target.type not in _valid_types or target.dtype == dest_dtype:
        return 0

    assert (
        target.dtype == src_dtype
    ), "The real dtype({}) is not equal to the src dtype({})".format(
        _dtype_to_str(target.dtype), _dtype_to_str(src_dtype)
    )

    cast_name = target.name + '.cast_' + _dtype_to_str(dest_dtype)
    existing = block.vars.get(cast_name)
    if existing is not None and existing.dtype == dest_dtype:
        # A suitable cast variable already exists; reuse it as-is.
        return 0

    cast_var = block.create_var(
        name=cast_name,
        dtype=dest_dtype,
        persistable=False,
        stop_gradient=target.stop_gradient,
    )
    block._insert_op(
        idx,
        type="cast",
        inputs={"X": target},
        outputs={"Out": cast_var},
        attrs={"in_dtype": target.dtype, "out_dtype": cast_var.dtype},
    )
    op_var_rename_map[block.idx][target.name] = cast_var.name
    return 1


181
def _is_in_fp32_varnames(op, amp_lists):
A
arlesniak 已提交
182 183 184
    if not amp_lists.fp32_varnames:
        return False

185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210
    for in_name in op.input_arg_names:
        if in_name in amp_lists.fp32_varnames:
            return True

    for out_name in op.output_arg_names:
        if out_name in amp_lists.fp32_varnames:
            return True

    return False


def _need_keep_fp32(op, unsupported_op_list, use_bf16_guard):
    if op.type in unsupported_op_list:
        # the highest priority condition: If ops don't have bf16 computing kernels,
        # they must be executed in fp32 calculation pattern.
        return True

    # process ops about learning rate
    in_out_arg_names = []
    in_out_arg_names.extend(list(op.input_arg_names))
    in_out_arg_names.extend(list(op.output_arg_names))
    for name in in_out_arg_names:
        if "learning_rate" in name:
            return True

    if use_bf16_guard:
211 212 213
        if op.has_attr("op_namescope") and (
            _bf16_guard_pattern in op.attr("op_namescope")
        ):
214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249
            # op in bf16 guard
            return False
        else:
            # op not in bf16 guard
            return True
    else:
        return False


@signature_safe_contextmanager
def bf16_guard():
    """
    As for the pure bf16 training, if users set `use_bf16_guard` to True,
    only those ops created in the context manager `bf16_guard` will be
    transformed as bfloat16 type.

    Ops created inside this context run under
    ``framework.name_scope(prefix=_bf16_guard_pattern)``. When
    `use_bf16_guard` is True, `_need_keep_fp32` treats an op as guarded if
    that pattern occurs in the op's ``op_namescope`` attribute. All other ops
    are kept in FP32.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F
            paddle.enable_static()
            data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
            conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)

            with paddle.static.amp.bf16_guard():
                bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
                pool = F.max_pool2d(bn, kernel_size=2, stride=2)
                hidden = paddle.static.nn.fc(pool, size=10)
                loss = paddle.mean(hidden)
    """
    with framework.name_scope(prefix=_bf16_guard_pattern):
        yield


def are_post_ops_bf16(post_ops, keep_fp32_ops):
    """
    Return True when no op in ``post_ops`` is in ``keep_fp32_ops``.

    ``post_ops`` is a list of op lists, one inner list of consumers per
    output variable (as assembled by `cast_initializers_to_bf16`).
    """
    return not any(
        consumer in keep_fp32_ops
        for consumers in post_ops
        for consumer in consumers
    )


def cast_initializers_to_bf16(
    startup_prog,
    amp_lists,
    block,
    all_ops,
    keep_fp32_ops,
    to_bf16_var_names=None,
):
    """
    Switch initializer ops of ``startup_prog`` to produce BF16 directly.

    An op in ``startup_prog``'s global block whose type is in
    ``amp_lists.bf16_initializer_list`` is converted only if two conditions
    hold. First, every output variable exists in ``block`` with a type in
    ``_valid_types``. Second, no op in ``all_ops`` that consumes those
    outputs is in ``keep_fp32_ops``. For a converted initializer:

      * FP32 output variables in ``block`` are retyped to BF16;
      * the outputs are removed from ``to_bf16_var_names`` (if given), so
        `cast_parameters_to_bf16` does not convert their values again;
      * an FP32 ``dtype`` attribute on the initializer is set to BF16.

    Args:
        startup_prog (Program): Program holding the initializer ops.
        amp_lists (AutoMixedPrecisionListsBF16): Provides
            ``bf16_initializer_list``.
        block (Block): Block in which the initialized variables are looked up.
        all_ops (list): Ops searched for consumers of the initialized vars.
        keep_fp32_ops (set): Ops that must stay in FP32.
        to_bf16_var_names (set|list, optional): Names still awaiting a value
            conversion; updated in place. Default None.
    """
    fp32 = core.VarDesc.VarType.FP32
    bf16 = core.VarDesc.VarType.BF16

    for op in startup_prog.global_block().ops:
        if str(op.type) not in amp_lists.bf16_initializer_list:
            continue

        convertible = True
        op_post_ops = []
        op_out_vars = []
        for out_name in op.output_names:
            for out_var_name in op.output(out_name):
                # `block.var` raises for unknown names; `vars.get` lets a
                # variable that is absent from `block` disqualify this
                # initializer instead of aborting the whole pass.
                out_var = block.vars.get(out_var_name)
                if out_var is None or out_var.type not in _valid_types:
                    convertible = False
                    break
                op_post_ops.append(
                    find_true_post_op(all_ops, op, out_var_name, True)
                )
                op_out_vars.append(out_var)
            if not convertible:
                break

        if not convertible or not are_post_ops_bf16(op_post_ops, keep_fp32_ops):
            continue

        for out_var in op_out_vars:
            if out_var.dtype == fp32:
                out_var.desc.set_dtype(bf16)
            if (
                to_bf16_var_names is not None
                and out_var.name in to_bf16_var_names
            ):
                to_bf16_var_names.remove(out_var.name)
        if op.has_attr('dtype') and op.attr('dtype') == fp32:
            op._set_attr('dtype', bf16)


def _find_var_with_global_fallback(block, global_block, var_name):
    """
    Look up ``var_name`` in ``block``, retrying in ``global_block``.

    ``block.var`` raises ``ValueError`` for names that are not in ``block``
    (e.g. a sub-block referencing a global variable). Those names are then
    looked up in the global block. Any error raised by that second lookup
    is not caught here.
    """
    try:
        return block.var(var_name)
    except ValueError as e:
        _logger.debug("-- {}, try to get it in the global block --".format(e))
        var = global_block.var(var_name)
        if var is not None:
            _logger.debug(
                "-- var {} is got in the global block --".format(var_name)
            )
        return var


def cast_model_to_bf16(
    program, startup_prog=None, amp_lists=None, use_bf16_guard=True
):
    """
    Traverse all ops in the whole model and set their inputs and outputs
    to the bf16 data type, except for ops that must be kept in FP32.

    Ops for which `_need_keep_fp32` returns True are kept in FP32. Cast ops
    are inserted in front of them (BF16 -> FP32), and after them for
    consumers that are not kept in FP32 (FP32 -> BF16). For batch_norm,
    fused_bn_add_activation and layer_norm, only the ``X``/``Z`` inputs and
    the ``Y`` output are switched; their other inputs and outputs keep their
    dtype.

    Args:
        program (Program): The used program.
        startup_prog (Program, optional): Startup program whose initializer
            ops may be switched to produce BF16 directly (see
            `cast_initializers_to_bf16`). Default None: initializers are not
            touched.
        amp_lists (AutoMixedPrecisionListsBF16, optional): Op lists to use.
            Default None, meaning ``AutoMixedPrecisionListsBF16()``.
        use_bf16_guard(bool): Determine whether to use `bf16_guard` when
                              constructing the program. Default True.

    Returns:
        set: Names of input variables retyped from FP32 to BF16 whose stored
        values still need converting, e.g. by `cast_parameters_to_bf16`.
    """
    if amp_lists is None:
        amp_lists = AutoMixedPrecisionListsBF16()
    fp32 = core.VarDesc.VarType.FP32
    bf16 = core.VarDesc.VarType.BF16
    norm_op_types = {'batch_norm', 'fused_bn_add_activation', 'layer_norm'}

    global_block = program.global_block()
    keep_fp32_ops = set()
    to_bf16_var_names = set()
    # NOTE: nothing adds ops to this set yet, so the FP32 -> BF16 pre-cast
    # branch in the second pass is currently never taken.
    to_bf16_pre_cast_ops = set()
    origin_ops = []
    for block in program.blocks:
        origin_ops.extend(block.ops)

    # Pass 1: retype inputs/outputs of every op that may run in BF16.
    for block in program.blocks:
        for op in block.ops:
            if op.type == 'create_py_reader' or op.type == 'read':
                continue
            if _need_keep_fp32(op, amp_lists.unsupported_list, use_bf16_guard):
                keep_fp32_ops.add(op)
                continue  # processed in pass 2

            for in_name in op.input_names:
                if op.type in norm_op_types and in_name not in {'X', 'Z'}:
                    continue
                for in_var_name in op.input(in_name):
                    in_var = _find_var_with_global_fallback(
                        block, global_block, in_var_name
                    )
                    if in_var is None or in_var.type not in _valid_types:
                        continue
                    if in_var.dtype == fp32:
                        in_var.desc.set_dtype(bf16)
                        to_bf16_var_names.add(in_var_name)
                    _logger.debug(
                        "-- op type: {}, in var name: {}, in var dtype: {} --".format(
                            op.type, in_var_name, in_var.dtype
                        )
                    )

            for out_name in op.output_names:
                if op.type in norm_op_types and out_name != 'Y':
                    continue
                for out_var_name in op.output(out_name):
                    out_var = _find_var_with_global_fallback(
                        block, global_block, out_var_name
                    )
                    if out_var is None or out_var.type not in _valid_types:
                        continue
                    if out_var.dtype == fp32:
                        out_var.desc.set_dtype(bf16)
                    _logger.debug(
                        "-- op type: {}, out var name: {}, out var dtype: {} --".format(
                            op.type, out_var_name, out_var.dtype
                        )
                    )

            for attr_name in ['in_dtype', 'out_dtype', 'dtype']:
                if op.has_attr(attr_name) and op.attr(attr_name) == fp32:
                    op._set_attr(attr_name, bf16)
            if op.has_attr('use_mkldnn'):
                op._set_attr('use_mkldnn', True)
            if op.has_attr('mkldnn_data_type'):
                op._set_attr('mkldnn_data_type', 'bfloat16')

    # Decide initializer dtypes once, after every block has been scanned.
    # Consumers in all blocks are then visible and keep_fp32_ops is
    # complete. Deciding per block could switch an initializer to BF16
    # even though an FP32-only consumer exists in another block.
    if startup_prog is not None:
        cast_initializers_to_bf16(
            startup_prog,
            amp_lists,
            global_block,
            origin_ops,
            keep_fp32_ops,
            to_bf16_var_names,
        )

    # Pass 2: process ops in keep_fp32_ops by inserting cast ops around them.
    op_var_rename_map = [
        collections.OrderedDict() for _ in range(len(program.blocks))
    ]
    for block in program.blocks:
        ops = block.ops
        idx = 0
        while idx < len(ops):
            op = ops[idx]
            num_cast_ops = 0
            if op not in keep_fp32_ops:
                if op in to_bf16_pre_cast_ops:
                    num_cast_ops += _insert_cast_op(block, op, idx, fp32, bf16)
            else:
                pre_cast_num = _insert_cast_op(block, op, idx, bf16, fp32)
                num_cast_ops += pre_cast_num
                for out_var_name in op.output_arg_names:
                    out_var = block.vars.get(out_var_name)
                    if out_var is None or out_var.type not in _valid_types:
                        continue
                    if out_var.dtype != bf16:
                        continue
                    out_var.desc.set_dtype(fp32)
                    for post_op in find_true_post_op(ops, op, out_var_name):
                        if post_op in keep_fp32_ops:
                            continue
                        # Insert right after `op`, which now sits at
                        # idx + pre_cast_num.
                        num_cast_ops += _insert_cast_post_op(
                            block,
                            op,
                            idx + pre_cast_num + 1,
                            fp32,
                            bf16,
                            out_var_name,
                            op_var_rename_map,
                        )
            idx += num_cast_ops + 1

    _rename_op_input(program, op_var_rename_map, origin_ops, keep_fp32_ops)
    return to_bf16_var_names


def cast_parameters_to_bf16(place, program, scope=None, to_bf16_var_names=None):
    """
    Convert the stored values of selected parameters to BF16.

    For every parameter of ``program`` whose name is in
    ``to_bf16_var_names``, its tensor is read from ``scope``, encoded with
    `convert_float_to_uint16` (BF16 bit patterns stored as uint16), and
    written back on ``place``. Parameters not named there are left
    untouched. For example, `cast_model_to_bf16` only adds the ``X``/``Z``
    inputs of normalization ops, so their other parameter inputs are not
    converted through that path.

    Args:
        place(fluid.CPUPlace|fluid.CUDAPlace): `place` is used to restore the BF16 weight tensors.
        program (Program): The used program.
        scope(fluid.Scope, optional): `scope` is used to get the FP32 weight tensor values.
                                      Default is None, meaning the global scope.
        to_bf16_var_names(set|list, optional): The data types of vars in `to_bf16_var_names`
                                               will be set to BF16. Usually, it is the returned
                                               value of `cast_model_to_bf16` API.
    """
    # Copy into a set: a list would make each membership test below O(n).
    bf16_var_names = set(to_bf16_var_names) if to_bf16_var_names else set()
    if not bf16_var_names:
        return

    all_parameters = []
    for block in program.blocks:
        all_parameters.extend(block.all_parameters())

    var_scope = scope if scope is not None else global_scope()
    for param in all_parameters:
        if param.name not in bf16_var_names:
            continue
        _logger.debug("---- cast {} to bf16 dtype ----".format(param.name))
        param_t = var_scope.find_var(param.name).get_tensor()
        data = np.array(param_t)
        param_t.set(convert_float_to_uint16(data), place)


def _resolve_gray_op_dtype(block, ops, op, amp_lists, fp32_op_set, bf16_op_set):
    """
    Classify a gray-list op from the ops that produce its inputs.

    Returns 'fp32' if any producer is an FP32 op (in ``fp32_op_set`` or of a
    type in ``amp_lists.fp32_list``). Otherwise returns 'bf16' if any
    producer is a BF16 op (in ``bf16_op_set`` or ``amp_lists.bf16_list``).
    Otherwise returns None. Every input is inspected even once a verdict is
    known.
    """
    has_fp32_producer = False
    has_bf16_producer = False
    for in_name in op.input_names:
        if not in_name:
            continue
        for in_var_name in op.input(in_name):
            producer = block.var(in_var_name).op
            if producer is None:
                # This input isn't the output of another op.
                continue
            if producer is op:
                # `op` also writes this variable, so search earlier ops for
                # the one that really produced it.
                producer = find_true_prev_op(ops, op, in_var_name)
                if producer is None:
                    continue
            if producer in fp32_op_set or producer.type in amp_lists.fp32_list:
                has_fp32_producer = True
            elif (
                producer in bf16_op_set or producer.type in amp_lists.bf16_list
            ):
                has_bf16_producer = True

    if has_fp32_producer:
        return 'fp32'
    if has_bf16_producer:
        return 'bf16'
    return None


def rewrite_program_bf16(main_prog, amp_lists=None):
    """
    Traverse all ops in the global block and insert cast ops according to
    the precision set each op is assigned to.

    0. When an input/output name of an op is in ``amp_lists.fp32_varnames``,
       add it to the fp32 set.
    1. When an op belongs to the fp32 list, add it to the fp32 set.
    2. When an op belongs to the bf16 list, add it to the bf16 set.
    3. When an op belongs to the gray list: if one of its inputs is the
       output of an fp32-set op or fp32-list op, add it to the fp32 set.
       Otherwise, if one of its inputs is the output of a bf16-set op or
       bf16-list op, add it to the bf16 set. Otherwise it joins neither set
       and gets no casts.
    4. When an op isn't in any list, add it to the fp32 set.
    5. Add the cast ops needed so that fp32-set ops compute in fp32 and
       bf16-set ops compute in bf16.

    Args:
        main_prog (Program): The main program for training.
        amp_lists (AutoMixedPrecisionListsBF16, optional): Op lists to use.
            Default None, meaning ``AutoMixedPrecisionListsBF16()``.
    """
    if amp_lists is None:
        amp_lists = AutoMixedPrecisionListsBF16()
    block = main_prog.global_block()
    ops = block.ops
    bf16_op_set = set()
    fp32_op_set = set()

    for op in ops:
        # NOTE(zhiqiu): 'create_py_reader' and 'read' are used in the
        # non-iterable DataLoader; we don't need to handle reader ops, and
        # the input of 'create_py_reader' is not in the block, which may
        # result in errors. See GeneratorLoader._init_non_iterable().
        if op.type in ('create_py_reader', 'read'):
            continue

        if amp_lists.fp32_varnames is not None and _is_in_fp32_varnames(
            op, amp_lists
        ):
            fp32_op_set.add(op)
        elif op.type in amp_lists.fp32_list:
            fp32_op_set.add(op)
        elif op.type in amp_lists.bf16_list:
            bf16_op_set.add(op)
        elif op.type in amp_lists.gray_list:
            verdict = _resolve_gray_op_dtype(
                block, ops, op, amp_lists, fp32_op_set, bf16_op_set
            )
            if verdict == 'fp32':
                fp32_op_set.add(op)
            elif verdict == 'bf16':
                bf16_op_set.add(op)
        else:
            # For numerical safety, ops not assigned to any list run in fp32.
            fp32_op_set.add(op)

    fp32 = core.VarDesc.VarType.FP32
    bf16 = core.VarDesc.VarType.BF16
    idx = 0
    while idx < len(ops):
        op = ops[idx]
        inserted = 0
        if op in fp32_op_set:
            inserted = _insert_cast_op(block, op, idx, bf16, fp32)
        elif op in bf16_op_set:
            if op.has_attr('use_mkldnn'):
                op._set_attr('use_mkldnn', True)
                op._set_attr('mkldnn_data_type', 'bfloat16')
            elif op.has_attr('dtype') and op.attr('dtype') == fp32:
                op._set_attr('dtype', bf16)
            inserted = _insert_cast_op(block, op, idx, fp32, bf16)
        # Skip over the cast ops just inserted in front of `op`.
        idx += inserted + 1