auto_parallel_fp16.py 32.0 KB
Newer Older
1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
#
3 4 5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9 10 11 12 13 14 15 16 17 18
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict

import paddle
from paddle.framework import core
19
from paddle.fluid.framework import default_main_program, default_startup_program
20 21 22
from paddle.fluid import unique_name
from .pass_base import register_pass
from paddle.fluid.data_feeder import check_variable_and_dtype, check_type
23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
from paddle.distributed.auto_parallel.utils import (
    set_var_dist_attr,
    naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
)
from paddle.distributed.auto_parallel.process_group import (
    get_world_process_group,
)
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
    AutoMixedPrecisionLists,
)
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
    _keep_layer_norm_scale_bias_to_fp32,
    _need_keep_fp32,
    _valid_types,
    _dtype_to_str,
)
from paddle.distributed.auto_parallel.dist_attribute import (
    OperatorDistributedAttribute,
)
from paddle.distributed.auto_parallel.utils import (
    is_forward_op,
    is_backward_op,
    OP_ROLE_KEY,
    OpRole,
)
48 49 50 51 52 53 54 55 56 57 58 59 60
from .auto_parallel_amp import AMPPass

# Shared handle to the world process group; its ``ranks`` are used further down
# as the process mesh for the variables and ops that this pass inserts.
world_process_group = get_world_process_group()
# Op types that the FP16 marking and cast-insertion steps leave untouched.
# 'cast' is listed because a vanilla program may already contain cast ops when
# the user writes Python "+, -, *, /" in the network.
__amp_skip_ops__ = [
    'create_py_reader',
    'create_double_buffer_reader',
    'while',
    'cast',
]


def set_op_dtype_to_fp16(op):
    """Switch an op's FP32 dtype attributes to FP16 in place.

    Only the ``in_dtype``, ``out_dtype`` and ``dtype`` attributes are looked
    at, in that order. Each one is rewritten only when the op has it and its
    current value is FP32; any other value is left alone.

    Args:
        op: operator exposing ``has_attr``, ``attr`` and ``_set_attr``.
    """
    fp32 = core.VarDesc.VarType.FP32
    fp16 = core.VarDesc.VarType.FP16
    for attr_name in ('in_dtype', 'out_dtype', 'dtype'):
        if op.has_attr(attr_name) and op.attr(attr_name) == fp32:
            op._set_attr(attr_name, fp16)


# adapted to also cover backward ops
def _keep_fp32_input(op, in_name):
    """Tell whether input slot ``in_name`` of ``op`` must stay in FP32.

    Args:
        op: operator; only its ``type`` is read.
        in_name: name of the input slot (e.g. ``'X'``, ``'Scale'``).

    Returns:
        True when the slot has to keep FP32, False otherwise (including for
        every op type that is not listed below).
    """
    op_type = op.type

    # forward ops
    if op_type == 'batch_norm':
        # Scale, Bias, Mean, Variance should be float32.
        return in_name != 'X'
    if op_type == 'layer_norm' and _keep_layer_norm_scale_bias_to_fp32():
        return in_name != 'X'
    if op_type == 'fused_bn_add_activation':
        return in_name not in {'X', 'Z'}
    if op_type == 'resnet_unit':
        return in_name not in {'X', 'FilterX', 'Z', 'FilterZ'}
    if op_type in ['fused_attention', 'fused_feedforward']:
        # Only the layer-norm scale/bias slots are kept in FP32 here.
        return in_name in {
            'LnScale',
            'LnBias',
            'Ln2Scale',
            'Ln2Bias',
            "Ln1Scale",
            "Ln1Bias",
        }

    # backward ops
    if op_type in ['batch_norm_grad', 'layer_norm_grad']:
        return in_name not in {'X', 'Y@GRAD'}

    return False


def _keep_fp32_output(op, out_name):
    """Tell whether output slot ``out_name`` of ``op`` must stay in FP32.

    Args:
        op: operator; only its ``type`` is read.
        out_name: name of the output slot (e.g. ``'Y'``, ``'X@GRAD'``).

    Returns:
        True when the slot has to keep FP32, False otherwise (including for
        every op type that is not listed below).
    """
    op_type = op.type

    # forward ops
    if op_type in ['batch_norm', 'fused_bn_add_activation']:
        return out_name != 'Y'
    if op_type == 'layer_norm' and _keep_layer_norm_scale_bias_to_fp32():
        return out_name != 'Y'
    if op_type == 'resnet_unit':
        return out_name not in {'Y', 'ConvX', 'ConvZ'}
    if op_type in ['fused_attention', 'fused_feedforward']:
        # Only the layer-norm mean/variance slots are kept in FP32 here.
        return out_name in {
            'LnMean',
            'LnVariance',
            'Ln2Mean',
            'Ln2Variance',
            'Ln1Mean',
            'Ln1Variance',
        }

    # backward ops
    if op_type in ['layer_norm_grad', 'batch_norm_grad']:
        return out_name != 'X@GRAD'

    return False


129
class FP16State:
    """Per-program bookkeeping and rewriting for the FP16 pass.

    The object decides, op by op, whether an op runs in FP16 or FP32, rewrites
    variable dtypes to match, and inserts the ``cast`` ops needed where the two
    precisions meet. ``_build_state`` runs these three phases in order.
    """

    def __init__(
        self,
        program,
        amp_list,
        dist_context,
        use_fp16_guard,
        input_data_var_names=None,
    ):
        """Store the pass inputs and create the empty bookkeeping tables.

        Args:
            program: program whose blocks are marked and rewritten.
            amp_list: object whose ``unsupported_list`` is handed to
                ``_need_keep_fp32``.
            dist_context: distributed context; provides
                ``dist_op_context.grad_op_id_to_op_id`` and the dist
                attributes of ops and tensors.
            use_fp16_guard: forwarded unchanged to ``_need_keep_fp32``.
            input_data_var_names: names of variables whose dtype must not be
                switched to FP16 when they feed an op; an empty list is used
                when nothing (or an empty value) is given.
        """
        self.program = program
        self.amp_list = amp_list
        self.use_fp16_guard = use_fp16_guard
        self.dist_context = dist_context
        self.grad_op_to_op_map = (
            self.dist_context.dist_op_context.grad_op_id_to_op_id
        )
        self.input_data_var_names = input_data_var_names or []
        # op_id --> True/False. 'True' means that the op should run in fp16 mode.
        self._op_fp16_dict = {}
        # a trick to determine leaf tensor node in program {varname: generator_op_id}
        self.forward_non_leaf_tensors = {}
        # record the cast ops that are inserted for a forward op:
        # {forward_op_id: [(output_name, input_name, out_dtype, in_dtype, slot_name), ]}
        self.forward_input_cast_ops = defaultdict(list)
        self.is_train = False
        # {var_name: [original ids of the ops that write it, in program order]}
        self.out_var_op_deps = {}

    def _is_fp16_op(self, op_id):
        """Return the recorded mode of ``op_id``: True, False, or None if unmarked."""
        return self._op_fp16_dict.get(op_id, None)

    def _build_state(self):
        """
        mark the execution mode (fp16 or fp32) for ops in all blocks
        include forward ops & backward ops

        After marking, tensor dtypes are resolved and cast ops are inserted
        for every block.

        Returns:
            ``self.is_train``, which ``_mark_op`` sets when it meets an op
            whose ``op_role`` is 257.
        """
        # mark op dtype
        # assume all backward block are behind forward blocks
        for block in self.program.blocks:
            for op in block.ops:
                # Record the writers of each output before marking the op:
                # _mark_op reads this table for in-place assign ops.
                for name in op.output_arg_names:
                    self.out_var_op_deps.setdefault(name, []).append(
                        op.desc.original_id()
                    )

                self._mark_op(op)

        # set forward tensor dtype
        for block in self.program.blocks:
            self.resolute_tensor_dtype(block)

        # insert cast ops
        for block in self.program.blocks:
            self.cast_block(block)

        return self.is_train

    def _mark_op(self, op):
        """Record in ``_op_fp16_dict`` whether ``op`` runs in FP16 or FP32.

        Ops listed in ``__amp_skip_ops__`` are ignored entirely. Forward ops
        are decided by ``_need_keep_fp32`` (with special cases for ``assign``);
        a backward op copies the mode of the forward op it maps to.
        """
        if op.type in __amp_skip_ops__:
            return

        if is_forward_op(op):

            # ernie inference trick
            if op.type == "assign" and "array_" in op.input_arg_names[0]:
                self._op_fp16_dict[op.desc.original_id()] = False
                return
            # If assign op is inplace-operation, assign op exec mode should be same with the created op of output_var.
            if op.type == "assign":
                out_name = op.output_arg_names[0]
                writers = self.out_var_op_deps[out_name]
                if len(writers) > 1:
                    self._op_fp16_dict[op.desc.original_id()] = bool(
                        self._op_fp16_dict[writers[0]]
                    )
                    return

            self._op_fp16_dict[op.desc.original_id()] = not _need_keep_fp32(
                op, self.amp_list.unsupported_list, self.use_fp16_guard
            )
            for var_name in op.output_arg_names:
                # assert var_name not in self.forward_non_leaf_tensors, "{}".format(var_name)
                self.forward_non_leaf_tensors[var_name] = op.desc.id()

        elif is_backward_op(op) == int(OpRole.Backward):

            if op.desc.original_id() in self.grad_op_to_op_map:
                fwd_op_id = self.grad_op_to_op_map[op.desc.original_id()]
                assert fwd_op_id in self._op_fp16_dict, "{}".format(str(op))
                self._op_fp16_dict[op.desc.original_id()] = self._op_fp16_dict[
                    fwd_op_id
                ]

        # NOTE(review): 257 is presumably OpRole.Backward | OpRole.Loss, i.e.
        # the program contains a loss-gradient op -- TODO confirm.
        if int(op.attr('op_role')) == 257:
            self.is_train = True

    def set_var_to_fp16(self, var_name, block):
        """Change variable ``var_name`` from FP32 to FP16 if it is eligible.

        The variable is looked up in ``block`` first and, on ``ValueError``,
        through ``block._var_recursive``. Nothing happens when the variable's
        type is not in ``_valid_types``, when its name contains ``"array_"``,
        or when its dtype is not FP32.
        """
        var = None
        try:
            var = block.var(var_name)
        except ValueError:
            var = block._var_recursive(var_name)
            # var = self.program.global_block().var(var_name)

        # NOTE(JZ-LIANG) "array_" is a hack to adopt for ernie3.0 inference, since there is
        # a trick which make the LOD_TENSOR_ARRAY to the float32 in while block to reset the LOD_TENSOR_ARRAY
        if var is None or var.type not in _valid_types or "array_" in var_name:
            return

        if var.dtype == core.VarDesc.VarType.FP32:
            var.desc.set_dtype(core.VarDesc.VarType.FP16)

    def _set_outputs_to_fp16(self, op, block):
        """Switch every output of ``op`` to FP16, except slots that must keep FP32."""
        for out_name in op.output_names:
            if _keep_fp32_output(op, out_name):
                continue
            for out_var_name in op.output(out_name):
                self.set_var_to_fp16(out_var_name, block)

    def _reset_outputs_to_fp32(self, op, block):
        """Switch FP16 outputs of ``op`` found in ``block.vars`` back to FP32."""
        for out_var_name in op.output_arg_names:
            out_var = block.vars.get(out_var_name)
            if out_var is None or out_var.type not in _valid_types:
                continue
            if out_var.dtype == core.VarDesc.VarType.FP16:
                out_var.desc.set_dtype(core.VarDesc.VarType.FP32)

    def resolute_tensor_dtype(self, block):
        """Make the variable dtypes of ``block`` agree with the marked op modes.

        For an FP16 op the outputs (and, for forward ops, the leaf inputs that
        are not input data) become FP16 and the op's dtype attributes are
        rewritten. For any other op, outputs that are FP16 are reset to FP32.
        """
        for op in block.ops:
            if is_forward_op(op):
                # NOTE (JZ-LIANG) un-expected cast op when user call "+, -, *, /" in python
                if self._is_fp16_op(op.desc.original_id()) or op.type == "cast":
                    for in_name in op.input_names:
                        if _keep_fp32_input(op, in_name):
                            continue
                        for in_var_name in op.input(in_name):
                            # Only leaf tensors are converted here; tensors
                            # produced by a forward op get their dtype from
                            # that producer's outputs.
                            if (
                                in_var_name not in self.forward_non_leaf_tensors
                                and in_var_name not in self.input_data_var_names
                            ):
                                self.set_var_to_fp16(in_var_name, block)
                    self._set_outputs_to_fp16(op, block)
                    set_op_dtype_to_fp16(op)
                # NOTE (JZ-LIANG) un-expected cast op when user call "+, -, *, /" in python
                else:
                    self._reset_outputs_to_fp32(op, block)
            elif is_backward_op(op):
                if self._is_fp16_op(op.desc.original_id()):
                    self._set_outputs_to_fp16(op, block)
                    set_op_dtype_to_fp16(op)
                # NOTE (JZ-LIANG) un-expected cast op when user call "+, -, *, /" in python
                else:
                    self._reset_outputs_to_fp32(op, block)

    def cast_block(self, block):
        """Insert the cast ops that ``block`` needs and sync it with the C++ desc.

        The block is walked by index because cast ops are inserted while
        iterating; after each op the index skips over the casts just added.
        """
        dist_op_context = self.dist_context.dist_op_context
        fp16 = core.VarDesc.VarType.FP16
        fp32 = core.VarDesc.VarType.FP32
        idx = 0
        while idx < len(block.ops):
            op = block.ops[idx]
            num_cast_ops = 0

            if op.type in __amp_skip_ops__:
                idx += 1
                continue
            elif is_forward_op(op):
                # An FP16 op needs its FP32 inputs cast down; any other op
                # (marked FP32 or unmarked) needs FP16 inputs cast up.
                if self._is_fp16_op(op.desc.original_id()):
                    src_dtype, dst_dtype = fp32, fp16
                else:
                    src_dtype, dst_dtype = fp16, fp32
                num_cast_ops = self._insert_forward_cast_ops(
                    op,
                    idx,
                    block,
                    src_dtype,
                    dst_dtype,
                    self.dist_context,
                )
            elif is_backward_op(op):
                if op.desc.original_id() in dist_op_context.grad_op_id_to_op_id:
                    if self._is_fp16_op(op.desc.original_id()):
                        src_dtype, dst_dtype = fp32, fp16
                    else:
                        src_dtype, dst_dtype = fp16, fp32
                    num_cast_ops = self._insert_backward_cast_ops(
                        op,
                        idx,
                        block,
                        src_dtype,
                        dst_dtype,
                        self.dist_context,
                    )
                elif op.type == "sum":
                    # all inputs dtype of sum should be equal and output dtype should follow input
                    out_var_name = op.output_arg_names[0]
                    first_in_name = op.input_arg_names[0]
                    out_var = block.var(out_var_name)
                    in_var = block._find_var_recursive(first_in_name)
                    for in_var_name in op.input_arg_names:
                        assert (
                            in_var.dtype == block.var(in_var_name).dtype
                        ), "{}, {}, {}".format(
                            in_var, block.var(in_var_name), str(op)
                        )
                    out_var.desc.set_dtype(in_var.dtype)

            idx += num_cast_ops + 1
        block._sync_with_cpp()

    def _insert_forward_cast_ops(
        self, op, idx, block, src_dtype, dst_dtype, dist_context
    ):
        """Cast the ``src_dtype`` inputs of forward op ``op`` to ``dst_dtype``.

        For each such input a ``<name>.cast_<dtype>`` variable is used. The
        cast op itself is inserted at ``idx`` (i.e. before ``op``) only when
        that variable does not exist yet with the right dtype. The input of
        ``op`` is then renamed to the cast variable and the rename is recorded
        in ``forward_input_cast_ops`` for the matching backward op.

        Args:
            op: forward op whose inputs are inspected.
            idx: index of ``op`` in ``block.ops``.
            block: block that owns ``op``.
            src_dtype: dtype of the inputs that need a cast.
            dst_dtype: dtype the op expects.
            dist_context: distributed context used for the dist attributes.

        Returns:
            Number of cast ops actually inserted.
        """
        num_cast_ops = 0

        for in_name in op.input_names:
            if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
                op, in_name
            ):
                continue

            consume_op_attr = dist_context.get_op_dist_attr_for_program(op)
            assert consume_op_attr is not None
            for in_var_name in op.input(in_name):
                in_var = block._find_var_recursive(in_var_name)
                if (
                    in_var is None
                    or in_var.type not in _valid_types
                    or in_var.dtype == dst_dtype
                ):
                    continue

                if in_var.dtype == src_dtype:
                    cast_name = (
                        in_var.name + '.cast_' + _dtype_to_str(dst_dtype)
                    )
                    cast_var = block.vars.get(cast_name)
                    self.forward_input_cast_ops[op.desc.original_id()] += [
                        (cast_name, in_var.name, dst_dtype, src_dtype, in_name)
                    ]

                    in_var_dist_attr = consume_op_attr.get_input_dist_attr(
                        in_var.name
                    )
                    assert in_var_dist_attr is not None
                    # truly insert cast op
                    if cast_var is None or cast_var.dtype != dst_dtype:
                        # NOTE we make the cast op and var's dist attr as the op that consume the
                        # cast var instead of the op which generates the var
                        # refine op's dist_attr
                        ref_mesh = in_var_dist_attr.process_mesh
                        ref_mapping = in_var_dist_attr.dims_mapping

                        cast_var = block.create_var(
                            name=cast_name,
                            dtype=dst_dtype,
                            persistable=False,
                            stop_gradient=in_var.stop_gradient,
                        )
                        set_var_dist_attr(
                            dist_context, cast_var, ref_mapping, ref_mesh
                        )

                        cast_op = block._insert_op_without_sync(
                            idx,
                            type="cast",
                            inputs={"X": in_var},
                            outputs={"Out": cast_var},
                            attrs={
                                "in_dtype": in_var.dtype,
                                "out_dtype": cast_var.dtype,
                                OP_ROLE_KEY: OpRole.Forward,
                            },
                        )
                        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                            cast_op, ref_mesh, ref_mapping, dist_context
                        )
                        num_cast_ops += 1

                    op._rename_input(in_var.name, cast_name)
                    consume_op_attr.set_input_dist_attr(
                        cast_name, in_var_dist_attr
                    )

        if op.has_attr('out_dtype') and op.attr('out_dtype') != -1:
            assert op.attr('out_dtype') == dst_dtype

        return num_cast_ops

    def _insert_backward_cast_ops(
        self, op, idx, block, src_dtype, dst_dtype, dist_context
    ):
        """Mirror, on grad op ``op``, the input casts of its forward op.

        For every cast recorded for the forward op, the matching input of
        ``op`` is renamed to the cast variable. When the corresponding
        ``<slot>@GRAD`` output exists, it is redirected to a new variable and
        a cast op is inserted at ``idx + 1`` to convert that variable back to
        the original gradient.

        Args:
            op: backward op; its original id must be a key of
                ``dist_op_context.grad_op_id_to_op_id``.
            idx: index of ``op`` in ``block.ops``.
            block: block that owns ``op``.
            src_dtype: not read by this method; the dtypes recorded for the
                forward casts are used for the inserted cast ops.
            dst_dtype: dtype every checked output of ``op`` must already have.
            dist_context: distributed context used for the dist attributes.

        Returns:
            Number of cast ops inserted after ``op``.
        """
        num_cast_ops = 0
        original_id = op.desc.original_id()
        dist_op_context = dist_context.dist_op_context
        forward_op_id = dist_op_context.grad_op_id_to_op_id[original_id]

        grad_op_attr = dist_context.get_op_dist_attr_for_program(op)
        assert grad_op_attr is not None

        for out_var_name in op.output_arg_names:
            out_var = block.var(out_var_name)
            # NOTE(review): _keep_fp32_output compares against slot names but
            # receives a variable name here -- behavior kept as is, TODO confirm.
            if _keep_fp32_output(op, out_var.name):
                continue
            assert out_var.dtype == dst_dtype, "{}, {}".format(
                str(out_var), dst_dtype
            )

        # The recorded tuples carry the dtypes of the *forward* cast; they get
        # their own names so the method parameters are not shadowed.
        for (
            cast_name,
            src_name,
            fwd_dst_dtype,
            fwd_src_dtype,
            slot_name,
        ) in self.forward_input_cast_ops[forward_op_id]:

            # some forward output is not need by backward computation, e.g. logit in softmax_with_cross_entropy
            if slot_name not in op.input_names:
                continue

            # rename input
            assert src_name in op.input(
                slot_name
            ), "var: {} not in op's {}. {}".format(src_name, slot_name, str(op))
            src_var_dist_attr = grad_op_attr.get_input_dist_attr(src_name)
            assert src_var_dist_attr is not None
            op._rename_input(src_name, cast_name)
            grad_op_attr.set_input_dist_attr(cast_name, src_var_dist_attr)

            # create cast grad
            grad_slot_name = slot_name + "@GRAD"
            assert (
                grad_slot_name in op.output_names
            ), "[{}], Current Op: {}".format(grad_slot_name, str(op))

            # some forward input maybe stop_gradient=True, e.g. input_mask
            if len(op.output(grad_slot_name)) == 0:
                continue
            assert (
                len(op.output(grad_slot_name)) == 1
            ), "[{}], Current Op: {}".format(grad_slot_name, str(op))
            grad_name = op.output(grad_slot_name)[0]
            grad = block.var(grad_name)
            grad_dist_attr = grad_op_attr.get_output_dist_attr(grad_name)
            assert grad_dist_attr is not None, "{}".format(grad_name)
            ref_mesh = grad_dist_attr.process_mesh
            ref_mapping = grad_dist_attr.dims_mapping

            cast_grad = block.create_var(
                name=unique_name.generate_with_ignorable_key(
                    "".join([cast_name, '@GRAD'])
                ),
                dtype=fwd_dst_dtype,
                shape=grad.shape,
                type=grad.type,
                persistable=grad.persistable,
                stop_gradient=grad.stop_gradient,
            )
            dist_context.set_tensor_dist_attr_for_program(
                cast_grad, grad_dist_attr
            )
            op._rename_output(grad_name, cast_grad.name)
            grad_op_attr.set_output_dist_attr(cast_grad.name, grad_dist_attr)

            # add cast
            cast_op = block._insert_op_without_sync(
                idx + 1,
                type="cast",
                inputs={"X": [cast_grad.name]},
                outputs={"Out": [grad.name]},
                attrs={
                    "in_dtype": fwd_dst_dtype,
                    "out_dtype": fwd_src_dtype,
                    OP_ROLE_KEY: OpRole.Backward,
                },
            )
            grad.desc.set_dtype(fwd_src_dtype)

            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                cast_op, ref_mesh, ref_mapping, dist_context
            )
            num_cast_ops += 1

        return num_cast_ops


def _check_and_update_gradient(grads, loss_scaling, name, dist_context):
    """Append a ``check_finite_and_unscale`` op over ``grads`` to the main program.

    The op is appended to the global block of the default main program. It
    takes ``grads`` and ``loss_scaling`` as inputs, writes ``grads`` in place
    and produces a new boolean ``FoundInfinite`` variable. Distributed
    attributes for the new variable and op are registered in ``dist_context``
    using the world process group's ranks.

    Args:
        grads: tuple or list of gradient variables (float16/32/64).
        loss_scaling: variable passed as the op's ``Scale`` input.
        name: suffix used when generating the found-inf variable name.
        dist_context: distributed context that holds the dist attributes;
            every gradient must already have a tensor dist attribute in it.

    Returns:
        ``(grads, found_inf)`` where ``found_inf`` is the new bool variable.
    """
    main_block = paddle.static.default_main_program().global_block()
    main_block._sync_with_cpp()

    check_type(grads, 'x', (tuple, list), 'check_finite_and_unscale')
    for grad in grads:
        check_variable_and_dtype(
            grad,
            "x",
            ['float16', 'float32', 'float64'],
            'check_finite_and_unscale',
        )

    found_inf = main_block.create_var(
        name=unique_name.generate_with_ignorable_key(
            ".".join(['find_infinite_scale', name])
        ),
        shape=[1],
        dtype='bool',
        type=core.VarDesc.VarType.LOD_TENSOR,
        persistable=False,
        stop_gradient=False,
    )
    set_var_dist_attr(dist_context, found_inf, [-1], world_process_group.ranks)

    new_op = main_block.append_op(
        type='check_finite_and_unscale',
        inputs={'X': grads, 'Scale': loss_scaling},
        outputs={'Out': grads, 'FoundInfinite': found_inf},
        attrs={'op_role': OpRole.Optimize},
    )

    new_op_dist_attr = OperatorDistributedAttribute()
    new_op_dist_attr.process_mesh = world_process_group.ranks
    new_op_dist_attr.impl_idx = 0
    # The dedicated impl type is selected only when more than one rank exists.
    if len(world_process_group.ranks) > 1:
        new_op_dist_attr.impl_type = "check_finite_and_unscale"
    for g in grads:
        g_dist_attr = dist_context.get_tensor_dist_attr_for_program(g)
        assert g_dist_attr is not None
        # Each grad is both an input and an (in-place) output of the op.
        new_op_dist_attr.set_input_dims_mapping(
            g.name, g_dist_attr.dims_mapping
        )
        new_op_dist_attr.set_output_dims_mapping(
            g.name, g_dist_attr.dims_mapping
        )
    dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
    return grads, found_inf


def _split_grads(params_grads):
    """Split the gradients of ``params_grads`` by dtype.

    Args:
        params_grads: iterable of ``(param, grad)`` pairs.

    Returns:
        ``(grads, fp32_grads, fp16_grads)``: all gradients in input order,
        then those whose dtype is FP32, then those whose dtype is FP16.

    Raises:
        AssertionError: if some gradient is neither FP32 nor FP16.
    """
    grads = [grad for _, grad in params_grads]
    fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32]
    fp16_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP16]
    assert len(fp32_grads) + len(fp16_grads) == len(
        grads
    ), "Data types of all grads must be either fp16 or fp32."
    return grads, fp32_grads, fp16_grads


def _set_op_dist_attr_with_ranks(new_op, ranks, block, dist_context):
    """Register a dist attribute for ``new_op`` built from its variables.

    The op's process mesh is set to ``ranks`` and its ``impl_idx`` to 0. The
    dims mapping of every input and output is copied from the tensor dist
    attribute that ``dist_context`` already holds for that variable.

    Args:
        new_op: op that needs a dist attribute.
        ranks: value assigned to the op's ``process_mesh``.
        block: block in which the op's input/output variables are looked up.
        dist_context: distributed context to read from and write to.

    Raises:
        AssertionError: if a variable has no tensor dist attribute.
    """
    new_op_dist_attr = OperatorDistributedAttribute()
    new_op_dist_attr.process_mesh = ranks
    new_op_dist_attr.impl_idx = 0

    for var_name in new_op.input_arg_names:
        var = block.var(var_name)
        var_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
        assert var_dist_attr is not None
        new_op_dist_attr.set_input_dims_mapping(
            var_name, var_dist_attr.dims_mapping
        )

    for var_name in new_op.output_arg_names:
        var = block.var(var_name)
        var_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
        assert var_dist_attr is not None
        new_op_dist_attr.set_output_dims_mapping(
            var_name, var_dist_attr.dims_mapping
        )

    dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)


630 631 632
def _get_memcopy_idx(block, found_inf_var):
    """Return the index right after the ``reduce_any`` op that writes ``found_inf_var``.

    Args:
        block: block whose ``ops`` are scanned in order.
        found_inf_var: variable whose ``name`` must equal the first output
            of the anchor op.

    Returns:
        Index of the first matching op plus one.

    Raises:
        RuntimeError: if no such ``reduce_any`` op exists in ``block``.
    """
    # use reduce_any op for check_nan_inf as the anchor for now
    for idx, op in enumerate(block.ops):
        if (
            op.type == 'reduce_any'
            and op.output_arg_names[0] == found_inf_var.name
        ):
            return idx + 1

    raise RuntimeError(
        "not found the correct location for memcopy for found_inf_var."
    )
642 643 644 645


def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"):
    """Insert a ``memcpy`` op at ``idx`` that copies ``src_var`` to a new variable.

    A fresh, non-persistable LOD_TENSOR variable with the dtype, shape and
    ``stop_gradient`` of ``src_var`` is created as the copy target. Both the
    variable and the op get dist attributes over the world process group's
    ranks, and the block is synced with its C++ desc afterwards.

    Args:
        block: block that receives the new variable and op.
        idx: position in ``block.ops`` at which the op is inserted.
        src_var: variable to copy from.
        dist_context: distributed context that stores the dist attributes.
        direction: ``"D2H"`` or ``"H2D"``.

    Returns:
        The newly created output variable.

    Raises:
        NotImplementedError: if ``direction`` is neither ``"D2H"`` nor ``"H2D"``.
    """
    src_name = src_var.name
    output_var = block.create_var(
        # Prefix the source name. The previous `src_name.join(['memcopy_'])`
        # used src_name as a separator for a one-element list and therefore
        # dropped it from the generated name.
        name=unique_name.generate_with_ignorable_key(
            "".join(['memcopy_', src_name])
        ),
        dtype=src_var.dtype,
        shape=src_var.shape,
        type=core.VarDesc.VarType.LOD_TENSOR,
        persistable=False,
        stop_gradient=src_var.stop_gradient,
    )

    set_var_dist_attr(dist_context, output_var, [-1], world_process_group.ranks)

    # TODO to support CUDAPinned/NPU/XPU Places
    # NOTE(review): 0/1 are presumably the memcpy op's codes for CPUPlace and
    # CUDAPlace -- confirm against the memcpy op definition.
    if direction == "D2H":
        dst_place_type = 0
    elif direction == "H2D":
        # This branch used to repeat the "D2H" test and was unreachable, so
        # "H2D" always fell through to NotImplementedError.
        dst_place_type = 1
    else:
        raise NotImplementedError(
            "direction [{}] is not supported yet.".format(direction)
        )

    attrs = {'dst_place_type': dst_place_type}
    new_op = block._insert_op_without_sync(
        index=idx,
        type='memcpy',
        inputs={'X': [src_var]},
        outputs={'Out': [output_var]},
        attrs=attrs,
    )
    _set_op_dist_attr_with_ranks(
        new_op, world_process_group.ranks, block, dist_context
    )
    block._sync_with_cpp()
    return output_var


684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706
def cast_startup_program():
    """Make startup-program initializers produce FP16 for FP16 parameters.

    The dtype of every parameter in the default main program is collected.
    For each initialization op in the global block of the default startup
    program whose first output is a parameter that is FP16 in the main
    program, an FP32 ``dtype`` attribute is rewritten to FP16.

    Raises:
        AssertionError: if such an op has no ``dtype`` attribute.
    """
    main_program = default_main_program()
    startup_program = default_startup_program()

    param_to_dtype = {
        param.name: param.dtype
        for block in main_program.blocks
        for param in block.all_parameters()
    }

    def is_initialization_op(op):
        """Filter out communication ops and ops that cannot be initializers."""
        comm_op_prefix = "c_"
        if op.type.startswith(comm_op_prefix):
            return False

        # An op with no output has nothing to initialise. Without this guard
        # such an op (with no inputs either) passed the check below and the
        # caller then failed with IndexError on output_arg_names[0].
        if not op.output_arg_names:
            return False

        if len(op.output_arg_names) != 1 and len(op.input_arg_names) != 0:
            return False

        return True

    for op in startup_program.global_block().ops:
        if not is_initialization_op(op):
            continue
        output_name = op.output_arg_names[0]
        if param_to_dtype.get(output_name, None) == core.VarDesc.VarType.FP16:
            assert op.has_attr(
                'dtype'
            ), "initialization op is supported to has dtype attribute but got {}.".format(
                str(op)
            )
            if op.attr('dtype') == core.VarDesc.VarType.FP32:
                op._set_attr('dtype', core.VarDesc.VarType.FP16)


720 721 722
@register_pass("auto_parallel_fp16")
class FP16Pass(AMPPass):
    """Pass registered as ``auto_parallel_fp16``; extends ``AMPPass``.

    Configuration is read through ``get_attr``. The attributes used here are
    ``dist_context``, ``params_grads``, ``custom_white_list``,
    ``custom_black_list``, ``input_data``, ``use_fp16_guard``,
    ``use_dynamic_loss_scaling``, ``init_loss_scaling``, ``base_opt`` and
    ``use_optimizer_fp16``.
    """

    def __init__(self):
        super().__init__()

    # NOTE: Why can FP16Pass override _apply_single_impl instead of
    # apply_impl? AMP is an optimization pass for the serial program; in the
    # distributed scenario, all ranks should have the same modification.
    def _apply_single_impl(self, main_program, startup_program, context):
        """Rewrite ``main_program`` and ``startup_program`` for FP16.

        Steps, in order:

        1. Build an ``FP16State`` for ``main_program`` and call its
           ``_build_state()``; the return value is used as the "is training
           program" flag. Then run ``cast_startup_program()``.
        2. For a training program: initialize the AMP variables, scale the
           loss, and, when loss scaling is active (dynamic scaling enabled or
           ``init_loss_scaling != 1.0``), check the fp32 and fp16 gradient
           groups and reduce their results into a single ``found_inf``
           variable. With dynamic loss scaling the loss scaling is updated
           from that variable as well.
        3. Configure ``base_opt``: set ``_multi_precision`` and, if a
           ``found_inf`` variable was produced, register its name on the
           optimizer as the ``'found_inf'`` auxiliary variable.

        Args:
            main_program: program to be rewritten.
            startup_program: startup program paired with ``main_program``.
            context: pass context; not read by this method.
        """
        self.dist_context = self.get_attr("dist_context")
        params_grads = self.get_attr("params_grads")

        amp_list = AutoMixedPrecisionLists(
            set(self.get_attr("custom_white_list")),
            set(self.get_attr("custom_black_list")),
            None,
        )

        # NOTE: do not change the dtype of the input data; it is controlled
        # by the dataloader, which is outside the control of the FP16 pass.
        input_data_var_names = [var.name for var in self.get_attr("input_data")]

        with paddle.static.program_guard(main_program, startup_program):
            fp16_state = FP16State(
                main_program,
                amp_list,
                self.dist_context,
                self.get_attr("use_fp16_guard"),
                input_data_var_names,
            )
            is_train = fp16_state._build_state()

            cast_startup_program()

        if is_train:
            block = main_program.global_block()
            # Stays None when loss scaling is disabled, i.e. when no
            # inf/nan check is inserted into the program.
            found_inf = None

            def _annotate_last_op(var, expected_op_type):
                # Attach dist attrs (replicated over every rank of the world
                # process group) to `var` and to the op appended last, which
                # must be the op that produced `var`.
                set_var_dist_attr(
                    self.dist_context,
                    var,
                    [-1],
                    world_process_group.ranks,
                )
                new_op = block.ops[-1]
                assert new_op.type == expected_op_type
                _set_op_dist_attr_with_ranks(
                    new_op,
                    world_process_group.ranks,
                    block,
                    self.dist_context,
                )

            with paddle.static.program_guard(main_program, startup_program):
                # TODO (JZ-LIANG)support cast forward program only when inference
                self._init_amp_var()
                self._scale_loss()

                _, fp32_grads, fp16_grads = _split_grads(params_grads)

                if (
                    self.get_attr("use_dynamic_loss_scaling")
                    or self.get_attr("init_loss_scaling") != 1.0
                ):
                    # One check per non-empty gradient group, fp32 first.
                    found_infs = []
                    for group_grads, name_suffix in (
                        (fp32_grads, "@fp32"),
                        (fp16_grads, "@fp16"),
                    ):
                        if not group_grads:
                            continue
                        with main_program._optimized_guard([]):
                            _, group_found_inf = _check_and_update_gradient(
                                group_grads,
                                self._loss_scaling,
                                name_suffix,
                                self.dist_context,
                            )
                        found_infs.append(group_found_inf)

                    # Merge the per-group flags into one variable.
                    with main_program._optimized_guard([]):
                        all_infs = paddle.fluid.layers.concat(found_infs)
                        _annotate_last_op(all_infs, "concat")

                        found_inf = paddle.fluid.layers.reduce_any(all_infs)
                        _annotate_last_op(found_inf, "reduce_any")

                if self.get_attr("use_dynamic_loss_scaling"):
                    with main_program._optimized_guard([]):
                        if fp32_grads:
                            self._update_loss_scaling(fp32_grads, found_inf)
                        if fp16_grads:
                            self._update_loss_scaling(fp16_grads, found_inf)

            # modify optimizer
            base_opt = self.get_attr("base_opt")
            base_opt._multi_precision = True
            if self.get_attr("use_optimizer_fp16"):
                base_opt._multi_precision = False

            # Without loss scaling there is no found_inf variable to hand to
            # the optimizer, so the hookup below is skipped in that case.
            if found_inf is not None:
                if isinstance(
                    base_opt,
                    (paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW),
                ):
                    with main_program._optimized_guard([]):
                        # Adam/AdamW receive a memcpy'd copy of found_inf.
                        # NOTE(review): presumably a device-to-host copy (an
                        # earlier variant targeted paddle.CPUPlace()) --
                        # confirm against _insert_memcopy's default direction.
                        insert_idx = _get_memcopy_idx(block, found_inf)
                        found_inf = _insert_memcopy(
                            block, insert_idx, found_inf, self.dist_context
                        )
                    base_opt._set_auxiliary_var('found_inf', found_inf.name)
                elif hasattr(base_opt, "_set_auxiliary_var"):
                    base_opt._set_auxiliary_var('found_inf', found_inf.name)