# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict

import paddle
from paddle.distributed.auto_parallel.dist_attribute import (
    OperatorDistributedAttribute,
)
from paddle.distributed.auto_parallel.process_group import (
    get_world_process_group,
)
from paddle.distributed.auto_parallel.utils import (
    is_backward_op,
    is_forward_op,
    naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
    set_var_dist_attr,
)
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from paddle.fluid import unique_name
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
    AutoMixedPrecisionLists,
    _dtype_to_str,
    _keep_layer_norm_scale_bias_to_fp32,
    _need_keep_fp32,
    _valid_types,
)
from paddle.fluid.data_feeder import check_type, check_variable_and_dtype
from paddle.fluid.framework import default_main_program, default_startup_program
from paddle.framework import core

from .auto_parallel_amp import AMPPass
from .pass_base import register_pass
45 46 47 48 49 50 51 52 53 54 55 56

world_process_group = get_world_process_group()
# if user use python "+, -, * /" for network, there might be cast in vanilla program
__amp_skip_ops__ = [
    'create_py_reader',
    'create_double_buffer_reader',
    'while',
    'cast',
]


def set_op_dtype_to_fp16(op):
57 58 59 60
    if (
        op.has_attr('in_dtype')
        and op.attr('in_dtype') == core.VarDesc.VarType.FP32
    ):
61
        op._set_attr('in_dtype', core.VarDesc.VarType.FP16)
62 63 64 65
    if (
        op.has_attr('out_dtype')
        and op.attr('out_dtype') == core.VarDesc.VarType.FP32
    ):
66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84
        op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
    if op.has_attr('dtype') and op.attr('dtype') == core.VarDesc.VarType.FP32:
        op._set_attr('dtype', core.VarDesc.VarType.FP16)


# adapot for backward op
def _keep_fp32_input(op, in_name):
    op_type = op.type
    if op_type == 'batch_norm':
        # Scale, Bias, Mean, Variance should be float32.
        return in_name != 'X'
    if op_type == 'layer_norm' and _keep_layer_norm_scale_bias_to_fp32():
        return in_name != 'X'
    if op_type == 'fused_bn_add_activation':
        return in_name not in {'X', 'Z'}
    if op_type == 'resnet_unit':
        return in_name not in {'X', 'FilterX', 'Z', 'FilterZ'}
    if op_type in ['fused_attention', 'fused_feedforward']:
        return in_name in {
85 86 87 88 89 90
            'LnScale',
            'LnBias',
            'Ln2Scale',
            'Ln2Bias',
            "Ln1Scale",
            "Ln1Bias",
91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109
        }
    # backward
    if op_type in ['batch_norm_grad']:
        return in_name not in {'X', 'Y@GRAD'}
    if op_type in ['layer_norm_grad']:
        return in_name not in {'X', 'Y@GRAD'}
    return False


def _keep_fp32_output(op, out_name):
    op_type = op.type
    if op_type in ['batch_norm', 'fused_bn_add_activation']:
        return out_name != 'Y'
    if op_type == 'layer_norm' and _keep_layer_norm_scale_bias_to_fp32():
        return out_name != 'Y'
    if op_type == 'resnet_unit':
        return out_name not in {'Y', 'ConvX', 'ConvZ'}
    if op_type in ['fused_attention', 'fused_feedforward']:
        return out_name in {
110 111 112 113 114 115
            'LnMean',
            'LnVariance',
            'Ln2Mean',
            'Ln2Variance',
            'Ln1Mean',
            'Ln1Variance',
116 117 118 119 120 121 122 123 124
        }
    # backward
    if op_type in ['layer_norm_grad']:
        return out_name != 'X@GRAD'
    if op_type in ['batch_norm_grad']:
        return out_name != 'X@GRAD'
    return False


125
class FP16State:
126 127 128 129 130 131 132 133
    def __init__(
        self,
        program,
        amp_list,
        dist_context,
        use_fp16_guard,
        input_data_var_names=None,
    ):
134 135 136 137
        self.program = program
        self.amp_list = amp_list
        self.use_fp16_guard = use_fp16_guard
        self.dist_context = dist_context
138 139 140
        self.grad_op_to_op_map = (
            self.dist_context.dist_op_context.grad_op_id_to_op_id
        )
141 142 143 144
        if input_data_var_names:
            self.input_data_var_names = input_data_var_names
        else:
            self.input_data_var_names = []
145 146 147
        self._op_fp16_dict = (
            {}
        )  # op_id --> True/False. 'True' means that the op is should run in fp16 mode.
148 149 150 151 152 153 154
        # a trick to determine leaf tensor node in program {varname: generator_op_id}
        self.forward_non_leaf_tensors = {}
        # record the cast ops that are inserted for a forward
        self.forward_input_cast_ops = defaultdict(
            list
        )  # {forward_op_id: [(output_name, input_name, out_dtype, in_dtype, slot_name), ]}
        self.is_train = False
155
        self.out_var_op_deps = {}
156 157 158 159 160 161

    def _is_fp16_op(self, op_id):
        return self._op_fp16_dict.get(op_id, None)

    def _build_state(self):
        """
162
        mark the execution mode (fp16 or fp32) for ops in all blocks
163 164 165 166 167 168
        include forward ops & backward ops
        """
        # mark op dtype
        # assume all backward block are behind forward blocks
        for block in self.program.blocks:
            for op in block.ops:
169 170 171 172 173 174 175 176
                for name in op.output_arg_names:
                    if name not in self.out_var_op_deps:
                        self.out_var_op_deps[name] = [op.desc.original_id()]
                    else:
                        self.out_var_op_deps[name].extend(
                            [op.desc.original_id()]
                        )

177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197
                self._mark_op(op)

        # set forward tensor dtype
        for block in self.program.blocks:
            self.resolute_tensor_dtype(block)

        # insert cast ops
        for block in self.program.blocks:
            self.cast_block(block)

        return self.is_train

    def _mark_op(self, op):

        if op.type in __amp_skip_ops__:
            return

        if is_forward_op(op):

            # ernie inference trick
            if op.type == "assign" and "array_" in op.input_arg_names[0]:
198
                self._op_fp16_dict[op.desc.original_id()] = False
199
                return
200 201 202 203 204 205 206 207 208 209 210 211
            # If assign op is inplace-operation, assign op exec mode should be same with the created op of output_var.
            if op.type == "assign":
                out_name = op.output_arg_names[0]
                if len(self.out_var_op_deps[out_name]) > 1:
                    if not self._op_fp16_dict[
                        self.out_var_op_deps[out_name][0]
                    ]:
                        self._op_fp16_dict[op.desc.original_id()] = False
                    else:
                        self._op_fp16_dict[op.desc.original_id()] = True
                    return

212 213 214
            if _need_keep_fp32(
                op, self.amp_list.unsupported_list, self.use_fp16_guard
            ):
215
                self._op_fp16_dict[op.desc.original_id()] = False
216
            else:
217
                self._op_fp16_dict[op.desc.original_id()] = True
218 219 220 221 222 223
            for var_name in op.output_arg_names:
                # assert var_name not in self.forward_non_leaf_tensors, "{}".format(var_name)
                self.forward_non_leaf_tensors[var_name] = op.desc.id()

        elif is_backward_op(op) == int(OpRole.Backward):

224 225
            if op.desc.original_id() in self.grad_op_to_op_map:
                fwd_op_id = self.grad_op_to_op_map[op.desc.original_id()]
226
                assert fwd_op_id in self._op_fp16_dict, "{}".format(str(op))
227 228 229
                self._op_fp16_dict[op.desc.original_id()] = self._op_fp16_dict[
                    fwd_op_id
                ]
230 231 232 233 234 235 236 237 238

        if int(op.attr('op_role')) == 257:
            self.is_train = True

    def set_var_to_fp16(self, var_name, block):
        var = None
        try:
            var = block.var(var_name)
        except ValueError as e:
239 240
            var = block._var_recursive(var_name)
            # var = self.program.global_block().var(var_name)
241

242
        # NOTE(JZ-LIANG) "array_" is a hack to adopt for ernie3.0 inference, since there is
243 244 245 246 247 248 249 250 251 252 253 254
        # a trick which make the LOD_TENSOR_ARRAY to the float32 in while block to reset the LOD_TENSOR_ARRAY
        if var is None or var.type not in _valid_types or "array_" in var_name:
            return

        if var.dtype == core.VarDesc.VarType.FP32:
            var.desc.set_dtype(core.VarDesc.VarType.FP16)

    def resolute_tensor_dtype(self, block):

        for op in block.ops:
            if is_forward_op(op):
                # NOTE (JZ-LIANG) un-expected cast op when user call "+, -, *, /" in python
255 256 257 258
                if (
                    self._is_fp16_op(op.desc.original_id()) is True
                    or op.type == "cast"
                ):
259 260 261 262
                    for in_name in op.input_names:
                        if _keep_fp32_input(op, in_name):
                            continue
                        for in_var_name in op.input(in_name):
263 264 265 266
                            if (
                                in_var_name not in self.forward_non_leaf_tensors
                                and in_var_name not in self.input_data_var_names
                            ):
267 268 269 270 271 272 273 274
                                self.set_var_to_fp16(in_var_name, block)
                    for out_name in op.output_names:
                        if _keep_fp32_output(op, out_name):
                            continue
                        for out_var_name in op.output(out_name):
                            self.set_var_to_fp16(out_var_name, block)
                    set_op_dtype_to_fp16(op)
                # NOTE (JZ-LIANG) un-expected cast op when user call "+, -, *, /" in python
275
                elif self._is_fp16_op(op.desc.original_id()) is False:
276 277 278 279 280 281 282
                    for out_var_name in op.output_arg_names:
                        out_var = block.vars.get(out_var_name)
                        if out_var is None or out_var.type not in _valid_types:
                            continue
                        if out_var.dtype == core.VarDesc.VarType.FP16:
                            out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
            elif is_backward_op(op):
283
                if self._is_fp16_op(op.desc.original_id()) is True:
284 285 286 287 288 289 290
                    for out_name in op.output_names:
                        if _keep_fp32_output(op, out_name):
                            continue
                        for out_var_name in op.output(out_name):
                            self.set_var_to_fp16(out_var_name, block)
                    set_op_dtype_to_fp16(op)
                # NOTE (JZ-LIANG) un-expected cast op when user call "+, -, *, /" in python
291
                elif self._is_fp16_op(op.desc.original_id()) is False:
292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309
                    for out_var_name in op.output_arg_names:
                        out_var = block.vars.get(out_var_name)
                        if out_var is None or out_var.type not in _valid_types:
                            continue
                        if out_var.dtype == core.VarDesc.VarType.FP16:
                            out_var.desc.set_dtype(core.VarDesc.VarType.FP32)

    def cast_block(self, block):
        dist_op_context = self.dist_context.dist_op_context
        idx = 0
        while idx < len(block.ops):
            op = block.ops[idx]
            num_cast_ops = 0

            if op.type in __amp_skip_ops__:
                idx += 1
                continue
            elif is_forward_op(op):
310
                if self._is_fp16_op(op.desc.original_id()) is False:
311
                    num_cast_ops = self._insert_forward_cast_ops(
312 313 314 315 316 317 318
                        op,
                        idx,
                        block,
                        core.VarDesc.VarType.FP16,
                        core.VarDesc.VarType.FP32,
                        self.dist_context,
                    )
319
                elif self._is_fp16_op(op.desc.original_id()) is True:
320
                    num_cast_ops = self._insert_forward_cast_ops(
321 322 323 324 325 326 327
                        op,
                        idx,
                        block,
                        core.VarDesc.VarType.FP32,
                        core.VarDesc.VarType.FP16,
                        self.dist_context,
                    )
328
            elif is_backward_op(op):
329
                if op.desc.original_id() in dist_op_context.grad_op_id_to_op_id:
330
                    if self._is_fp16_op(op.desc.original_id()) is False:
331
                        num_cast_ops = self._insert_backward_cast_ops(
332 333 334 335 336 337 338
                            op,
                            idx,
                            block,
                            core.VarDesc.VarType.FP16,
                            core.VarDesc.VarType.FP32,
                            self.dist_context,
                        )
339
                    elif self._is_fp16_op(op.desc.original_id()) is True:
340
                        num_cast_ops = self._insert_backward_cast_ops(
341 342 343 344 345 346 347
                            op,
                            idx,
                            block,
                            core.VarDesc.VarType.FP32,
                            core.VarDesc.VarType.FP16,
                            self.dist_context,
                        )
348 349 350 351 352 353 354
                elif op.type == "sum":
                    # all inputs dtype of sum should be equal and output dtype should follow input
                    out_var_name = op.output_arg_names[0]
                    in_var_name = op.input_arg_names[0]
                    out_var = block.var(out_var_name)
                    in_var = block._find_var_recursive(in_var_name)
                    for in_var_name in op.input_arg_names:
355 356 357 358 359
                        assert (
                            in_var.dtype == block.var(in_var_name).dtype
                        ), "{}, {}, {}".format(
                            in_var, block.var(in_var_name), str(op)
                        )
360 361 362 363 364
                    out_var.desc.set_dtype(in_var.dtype)

            idx += num_cast_ops + 1
        block._sync_with_cpp()

365 366 367
    def _insert_forward_cast_ops(
        self, op, idx, block, src_dtype, dst_dtype, dist_context
    ):
368 369 370 371 372

        num_cast_ops = 0

        for in_name in op.input_names:
            if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
373 374
                op, in_name
            ):
375 376 377 378 379 380
                continue

            consume_op_attr = dist_context.get_op_dist_attr_for_program(op)
            assert consume_op_attr is not None
            for in_var_name in op.input(in_name):
                in_var = block._find_var_recursive(in_var_name)
381 382 383 384 385
                if (
                    in_var is None
                    or in_var.type not in _valid_types
                    or in_var.dtype == dst_dtype
                ):
386 387 388
                    continue

                if in_var.dtype == src_dtype:
389 390 391
                    cast_name = (
                        in_var.name + '.cast_' + _dtype_to_str(dst_dtype)
                    )
392
                    cast_var = block.vars.get(cast_name)
393 394 395
                    self.forward_input_cast_ops[op.desc.original_id()] += [
                        (cast_name, in_var.name, dst_dtype, src_dtype, in_name)
                    ]
396 397

                    in_var_dist_attr = consume_op_attr.get_input_dist_attr(
398 399
                        in_var.name
                    )
400
                    assert in_var_dist_attr is not None
401
                    # truly insert cast op
402 403 404 405 406 407 408 409 410 411 412
                    if cast_var is None or cast_var.dtype != dst_dtype:
                        # NOTE we make the cast op and var's dist attr as the op that consume the
                        # cast var instead of the op which generates the var
                        # refine op's dist_attr
                        ref_mesh = in_var_dist_attr.process_mesh
                        ref_mapping = in_var_dist_attr.dims_mapping

                        cast_var = block.create_var(
                            name=cast_name,
                            dtype=dst_dtype,
                            persistable=False,
413 414 415 416 417
                            stop_gradient=in_var.stop_gradient,
                        )
                        set_var_dist_attr(
                            dist_context, cast_var, ref_mapping, ref_mesh
                        )
418

419 420 421
                        op_namescope = "/"
                        if op.has_attr('op_namescope'):
                            op_namescope = op.attr('op_namescope')
422 423 424 425 426 427 428 429
                        cast_op = block._insert_op_without_sync(
                            idx,
                            type="cast",
                            inputs={"X": in_var},
                            outputs={"Out": cast_var},
                            attrs={
                                "in_dtype": in_var.dtype,
                                "out_dtype": cast_var.dtype,
430 431 432
                                OP_ROLE_KEY: OpRole.Forward,
                            },
                        )
433 434 435
                        cast_op._set_attr(
                            'op_namescope', op_namescope
                        )  # for recompute
436
                        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
437 438
                            cast_op, ref_mesh, ref_mapping, dist_context
                        )
439 440 441
                        num_cast_ops += 1

                    op._rename_input(in_var.name, cast_name)
442 443 444
                    consume_op_attr.set_input_dist_attr(
                        cast_name, in_var_dist_attr
                    )
445 446 447 448 449 450

        if op.has_attr('out_dtype') and op.attr('out_dtype') != -1:
            assert op.attr('out_dtype') == dst_dtype

        return num_cast_ops

451 452 453
    def _insert_backward_cast_ops(
        self, op, idx, block, src_dtype, dst_dtype, dist_context
    ):
454 455 456

        num_cast_ops = 0
        op_id = op.desc.id()
457
        original_id = op.desc.original_id()
458
        dist_op_context = dist_context.dist_op_context
459
        forward_op_id = dist_op_context.grad_op_id_to_op_id[original_id]
460 461 462 463 464 465 466 467 468

        grad_op_attr = dist_context.get_op_dist_attr_for_program(op)
        assert grad_op_attr is not None

        for out_var_name in op.output_arg_names:
            out_var = block.var(out_var_name)
            if _keep_fp32_output(op, out_var.name):
                continue
            assert out_var.dtype == dst_dtype, "{}, {}".format(
469 470
                str(out_var), dst_dtype
            )
471

472 473 474 475 476 477 478
        for (
            cast_name,
            src_name,
            dst_dtype,
            src_dtype,
            slot_name,
        ) in self.forward_input_cast_ops[forward_op_id]:
479

480 481 482 483
            # some forward output is not need by backward computation, e.g. logit in softmax_with_cross_entropy
            if slot_name not in op.input_names:
                continue

484 485
            # rename input
            assert src_name in op.input(
486 487
                slot_name
            ), "var: {} not in op's {}. {}".format(src_name, slot_name, str(op))
488 489 490 491 492 493 494
            src_var_dist_attr = grad_op_attr.get_input_dist_attr(src_name)
            assert src_var_dist_attr is not None
            op._rename_input(src_name, cast_name)
            grad_op_attr.set_input_dist_attr(cast_name, src_var_dist_attr)

            # create cast grad
            grad_slot_name = slot_name + "@GRAD"
495 496
            if grad_slot_name not in op.output_names:
                continue
497 498

            # some forward input maybe stop_gradient=True, e.g. input_mask
499 500
            if len(op.output(grad_slot_name)) == 0:
                continue
501 502 503
            assert (
                len(op.output(grad_slot_name)) == 1
            ), "[{}], Current Op: {}".format(grad_slot_name, str(op))
504 505 506 507 508 509 510 511
            grad_name = op.output(grad_slot_name)[0]
            grad = block.var(grad_name)
            grad_dist_attr = grad_op_attr.get_output_dist_attr(grad_name)
            assert grad_dist_attr is not None, "{}".format(grad_name)
            ref_mesh = grad_dist_attr.process_mesh
            ref_mapping = grad_dist_attr.dims_mapping

            cast_grad = block.create_var(
512 513 514
                name=unique_name.generate_with_ignorable_key(
                    "".join([cast_name, '@GRAD'])
                ),
515 516 517 518
                dtype=dst_dtype,
                shape=grad.shape,
                type=grad.type,
                persistable=grad.persistable,
519 520
                stop_gradient=grad.stop_gradient,
            )
521
            dist_context.set_tensor_dist_attr_for_program(
522 523
                cast_grad, grad_dist_attr
            )
524 525 526 527 528 529 530 531 532 533 534 535
            op._rename_output(grad_name, cast_grad.name)
            grad_op_attr.set_output_dist_attr(cast_grad.name, grad_dist_attr)

            # add cast
            cast_op = block._insert_op_without_sync(
                idx + 1,
                type="cast",
                inputs={"X": [cast_grad.name]},
                outputs={"Out": [grad.name]},
                attrs={
                    "in_dtype": dst_dtype,
                    "out_dtype": src_dtype,
536 537 538
                    OP_ROLE_KEY: OpRole.Backward,
                },
            )
539 540 541
            grad.desc.set_dtype(src_dtype)

            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
542 543
                cast_op, ref_mesh, ref_mapping, dist_context
            )
544 545 546 547 548 549 550 551 552 553 554 555
            num_cast_ops += 1

        return num_cast_ops


def _check_and_update_gradient(grads, loss_scaling, name, dist_context):

    main_block = paddle.static.default_main_program().global_block()
    main_block._sync_with_cpp()

    check_type(grads, 'x', (tuple, list), 'check_finite_and_unscale')
    for e in grads:
556 557 558 559 560 561
        check_variable_and_dtype(
            e,
            "x",
            ['float16', 'float32', 'float64'],
            'check_finite_and_unscale',
        )
562 563

    found_inf = main_block.create_var(
564 565 566
        name=unique_name.generate_with_ignorable_key(
            ".".join(['find_infinite_scale', name])
        ),
567 568 569 570
        shape=[1],
        dtype='bool',
        type=core.VarDesc.VarType.LOD_TENSOR,
        persistable=False,
571 572
        stop_gradient=False,
    )
573 574 575 576
    set_var_dist_attr(dist_context, found_inf, [-1], world_process_group.ranks)

    inputs = {'X': grads, 'Scale': loss_scaling}
    outputs = {'Out': grads, 'FoundInfinite': found_inf}
577
    attrs = {'op_role': OpRole.Optimize}
578 579 580 581 582 583
    new_op = main_block.append_op(
        type='check_finite_and_unscale',
        inputs=inputs,
        outputs=outputs,
        attrs=attrs,
    )
584 585 586 587 588 589 590 591 592

    new_op_dist_attr = OperatorDistributedAttribute()
    new_op_dist_attr.process_mesh = world_process_group.ranks
    new_op_dist_attr.impl_idx = 0
    if len(world_process_group.ranks) > 1:
        new_op_dist_attr.impl_type = "check_finite_and_unscale"
    for g in grads:
        g_dist_attr = dist_context.get_tensor_dist_attr_for_program(g)
        assert g_dist_attr is not None
593 594 595 596 597 598
        new_op_dist_attr.set_input_dims_mapping(
            g.name, g_dist_attr.dims_mapping
        )
        new_op_dist_attr.set_output_dims_mapping(
            g.name, g_dist_attr.dims_mapping
        )
599 600 601 602 603 604 605 606
    dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
    return grads, found_inf


def _split_grads(params_grads):
    grads = [g for _, g in params_grads]
    fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32]
    fp16_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP16]
607 608 609
    assert len(fp32_grads) + len(fp16_grads) == len(
        grads
    ), "Data types of all grads must be either fp16 or fp32."
610 611 612 613 614 615 616 617 618 619 620
    return grads, fp32_grads, fp16_grads


def _set_op_dist_attr_with_ranks(new_op, ranks, block, dist_context):
    new_op_dist_attr = OperatorDistributedAttribute()
    new_op_dist_attr.process_mesh = ranks
    new_op_dist_attr.impl_idx = 0
    for var_name in new_op.input_arg_names:
        var = block.var(var_name)
        var_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
        assert var_dist_attr is not None
621 622 623
        new_op_dist_attr.set_input_dims_mapping(
            var_name, var_dist_attr.dims_mapping
        )
624 625 626 627
    for var_name in new_op.output_arg_names:
        var = block.var(var_name)
        var_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
        assert var_dist_attr is not None
628 629 630
        new_op_dist_attr.set_output_dims_mapping(
            var_name, var_dist_attr.dims_mapping
        )
631 632 633
    dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)


634 635 636
def _get_memcopy_idx(block, found_inf_var):
    # use reduce_any op for check_nan_inf as the anchor for now
    for idx, op in enumerate(block.ops):
637 638 639 640
        if (
            op.type == 'reduce_any'
            and op.output_arg_names[0] == found_inf_var.name
        ):
641 642 643
            return idx + 1

    raise RuntimeError(
644 645
        "not found the correct location for memcopy for found_inf_var."
    )
646 647 648 649


def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"):
    src_name = src_var.name
650 651 652 653 654 655 656 657 658 659
    output_var = block.create_var(
        name=unique_name.generate_with_ignorable_key(
            src_name.join(['memcopy_'])
        ),
        dtype=src_var.dtype,
        shape=src_var.shape,
        type=core.VarDesc.VarType.LOD_TENSOR,
        persistable=False,
        stop_gradient=src_var.stop_gradient,
    )
660 661 662 663 664 665 666 667 668 669

    set_var_dist_attr(dist_context, output_var, [-1], world_process_group.ranks)

    # TODO to support CUDAPinned/NPU/XPU Places
    if direction == "D2H":
        dst_place_type = 0
    elif direction == "D2H":
        dst_place_type = 1
    else:
        raise NotImplementedError(
670 671
            "direction [{}] is not supported yet.".format(direction)
        )
672 673

    attrs = {'dst_place_type': dst_place_type}
674 675 676 677 678 679 680 681 682 683
    new_op = block._insert_op_without_sync(
        index=idx,
        type='memcpy',
        inputs={'X': [src_var]},
        outputs={'Out': [output_var]},
        attrs=attrs,
    )
    _set_op_dist_attr_with_ranks(
        new_op, world_process_group.ranks, block, dist_context
    )
684 685 686 687
    block._sync_with_cpp()
    return output_var


688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710
def cast_startup_program():
    main_program = default_main_program()
    startup_program = default_startup_program()

    param_to_dtype = {}
    for block in main_program.blocks:
        for p in block.all_parameters():
            param_to_dtype[p.name] = p.dtype

    def is_initialization_op(op):
        comm_op_prefix = "c_"
        op_type = op.type
        if op_type.startswith(comm_op_prefix):
            return False

        if len(op.output_arg_names) != 1 and len(op.input_arg_names) != 0:
            return False

        return True

    for op in startup_program.global_block().ops:
        if is_initialization_op(op):
            output_name = op.output_arg_names[0]
711 712 713 714
            if (
                param_to_dtype.get(output_name, None)
                == core.VarDesc.VarType.FP16
            ):
715 716 717
                assert op.has_attr(
                    'dtype'
                ), "initialization op is supported to has dtype attribute but got {}.".format(
718 719
                    str(op)
                )
720 721 722 723
                if op.attr('dtype') == core.VarDesc.VarType.FP32:
                    op._set_attr('dtype', core.VarDesc.VarType.FP16)


@register_pass("auto_parallel_fp16")
class FP16Pass(AMPPass):
    """Pure-FP16 variant of the auto-parallel AMP pass.

    Casts the main/startup programs to FP16 via ``FP16State``, then (for
    training) inserts loss scaling, gradient inf/nan checking and loss-scale
    update ops, and finally patches the base optimizer so it consumes the
    aggregated ``found_inf`` flag.
    """

    def __init__(self):
        super().__init__()

    # NOTE: why FP16Pass can override apply_single_impl instead of
    # apply_impl? AMP is an optimization pass for serial program,
    # in distributed scenario, all ranks should have the same modification.
    def _apply_single_impl(self, main_program, startup_program, context):
        """Rewrite ``main_program``/``startup_program`` in place for FP16.

        Reads pass attributes (``dist_context``, ``params_grads``, white/black
        lists, loss-scaling settings, ``base_opt``) and mutates both programs;
        returns nothing.
        """
        self.dist_context = self.get_attr("dist_context")
        params_grads = self.get_attr("params_grads")

        amp_list = AutoMixedPrecisionLists(
            set(self.get_attr("custom_white_list")),
            set(self.get_attr("custom_black_list")),
            None,
        )

        # NOTE: do not change the input data dtype, since it is controlled by
        # the dataloader and is out of the control of the FP16 pass.
        input_data_var_names = [var.name for var in self.get_attr("input_data")]

        with paddle.static.program_guard(main_program, startup_program):
            fp16_state = FP16State(
                main_program,
                amp_list,
                self.dist_context,
                self.get_attr("use_fp16_guard"),
                input_data_var_names,
            )
            is_train = fp16_state._build_state()

            cast_startup_program()

        if is_train:
            with paddle.static.program_guard(main_program, startup_program):
                # TODO (JZ-LIANG)support cast forward program only when inference
                self._init_amp_var()
                self._scale_loss()

                grads, fp32_grads, fp16_grads = _split_grads(params_grads)

                if (
                    self.get_attr("use_dynamic_loss_scaling")
                    or self.get_attr("init_loss_scaling") != 1.0
                ):
                    # Check each precision group separately, then combine the
                    # per-group found_inf flags into a single scalar below.
                    found_infs = []
                    if fp32_grads:
                        with main_program._optimized_guard([]):
                            _, found_inf_fp32 = _check_and_update_gradient(
                                fp32_grads,
                                self._loss_scaling,
                                "@fp32",
                                self.dist_context,
                            )
                        found_infs.append(found_inf_fp32)
                    if fp16_grads:
                        with main_program._optimized_guard([]):
                            _, found_inf_fp16 = _check_and_update_gradient(
                                fp16_grads,
                                self._loss_scaling,
                                "@fp16",
                                self.dist_context,
                            )
                        found_infs.append(found_inf_fp16)
                    with main_program._optimized_guard([]):
                        block = main_program.global_block()

                        # Build concat + reduce_any by hand (instead of the
                        # layer helpers) so dist attrs can be set explicitly.
                        # all_infs = paddle.fluid.layers.concat(found_infs)
                        all_infs = block.create_var(
                            name=paddle.fluid.unique_name.generate_with_ignorable_key(
                                ".".join(['concat', 'tmp'])
                            ),
                            dtype=found_infs[0].dtype,
                            shape=None,
                            lod_level=found_infs[0].lod_level,
                            type=found_infs[0].type,
                            persistable=False,
                            stop_gradient=False,
                        )
                        concat_op = block.append_op(
                            type='concat',
                            inputs={'X': found_infs},
                            outputs={'Out': [all_infs]},
                            attrs={'axis': 0},
                        )
                        set_var_dist_attr(
                            self.dist_context,
                            all_infs,
                            [-1],
                            world_process_group.ranks,
                        )
                        _set_op_dist_attr_with_ranks(
                            concat_op,
                            world_process_group.ranks,
                            block,
                            self.dist_context,
                        )

                        # found_inf = paddle.fluid.layers.reduce_any(all_infs)
                        found_inf = block.create_var(
                            name=paddle.fluid.unique_name.generate_with_ignorable_key(
                                ".".join(['reduce_any', 'tmp'])
                            ),
                            dtype=all_infs.dtype,
                            shape=None,
                            lod_level=all_infs.lod_level,
                            type=all_infs.type,
                            persistable=False,
                            stop_gradient=False,
                        )
                        reduce_any_op = block.append_op(
                            type='reduce_any',
                            inputs={'X': all_infs},
                            outputs={'Out': found_inf},
                            attrs={
                                'dim': [0],
                                'keep_dim': False,
                                'reduce_all': True,
                            },
                        )
                        set_var_dist_attr(
                            self.dist_context,
                            found_inf,
                            [-1],
                            world_process_group.ranks,
                        )
                        _set_op_dist_attr_with_ranks(
                            reduce_any_op,
                            world_process_group.ranks,
                            block,
                            self.dist_context,
                        )

                if self.get_attr("use_dynamic_loss_scaling"):
                    with main_program._optimized_guard([]):
                        if fp32_grads:
                            self._update_loss_scaling(fp32_grads, found_inf)
                        if fp16_grads:
                            self._update_loss_scaling(fp16_grads, found_inf)

            # modify optimizer
            base_opt = self.get_attr("base_opt")
            base_opt._multi_precision = True
            if self.get_attr("use_optimizer_fp16"):
                base_opt._multi_precision = False
            # NOTE(review): `block` and `found_inf` are only bound when the
            # loss-scaling branch above was taken; with loss scaling fully
            # disabled this would raise NameError — confirm intended config
            # invariants with the pass owners.
            if isinstance(
                base_opt, (paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW)
            ):
                with main_program._optimized_guard([]):
                    # found_inf = paddle.tensor.creation._memcpy(
                    #     found_inf, paddle.CPUPlace())
                    insert_idx = _get_memcopy_idx(block, found_inf)
                    found_inf = _insert_memcopy(
                        block, insert_idx, found_inf, self.dist_context
                    )
                base_opt._set_auxiliary_var('found_inf', found_inf.name)
            elif hasattr(base_opt, "_set_auxiliary_var"):
                base_opt._set_auxiliary_var('found_inf', found_inf.name)