# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict

import paddle
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistAttr
from paddle.distributed.auto_parallel.process_group import (
    get_world_process_group,
)
from paddle.distributed.auto_parallel.utils import (
    is_backward_op,
    is_forward_op,
    naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
    set_var_dist_attr,
)
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
    AutoMixedPrecisionLists,
    _dtype_to_str,
    _keep_layer_norm_scale_bias_to_fp32,
    _need_keep_fp32,
    _valid_types,
)
from paddle.fluid.data_feeder import check_type, check_variable_and_dtype
from paddle.framework import core
from paddle.static import default_main_program, default_startup_program
from paddle.utils import unique_name

from ..auto_parallel.process_mesh import ProcessMesh
from .auto_parallel_amp import AMPPass
from .pass_base import register_pass

world_process_group = get_world_process_group()
# if the user uses python "+, -, *, /" for the network, there might be cast ops in the vanilla program
__amp_skip_ops__ = [
    'create_py_reader',
    'create_double_buffer_reader',
    'while',
    'cast',
]


def set_op_dtype_to_fp16(op):
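    """Flip an op's dtype attributes (in_dtype / out_dtype / dtype) from FP32 to FP16."""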
    if (
        op.has_attr('in_dtype')
        and op.attr('in_dtype') == core.VarDesc.VarType.FP32
    ):
        op._set_attr('in_dtype', core.VarDesc.VarType.FP16)
    if (
        op.has_attr('out_dtype')
        and op.attr('out_dtype') == core.VarDesc.VarType.FP32
    ):
        op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
    if op.has_attr('dtype') and op.attr('dtype') == core.VarDesc.VarType.FP32:
        op._set_attr('dtype', core.VarDesc.VarType.FP16)


# adapted to also cover backward ops
def _keep_fp32_input(op, in_name):
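    """Return True if the input slot `in_name` of `op` has to stay in FP32 under AMP (covers forward and backward ops)."""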
    op_type = op.type
    if op_type == 'batch_norm':
        # Scale, Bias, Mean, Variance should be float32.
        return in_name != 'X'
    if op_type == 'layer_norm' and _keep_layer_norm_scale_bias_to_fp32():
        return in_name != 'X'
    if op_type == 'fused_bn_add_activation':
        return in_name not in {'X', 'Z'}
    if op_type == 'resnet_unit':
        return in_name not in {'X', 'FilterX', 'Z', 'FilterZ'}
    if op_type in ['fused_attention', 'fused_feedforward']:
        return in_name in {
            'LnScale',
            'LnBias',
            'Ln2Scale',
            'Ln2Bias',
            "Ln1Scale",
            "Ln1Bias",
        }
    # backward
    if op_type in ['batch_norm_grad']:
        return in_name not in {'X', 'Y@GRAD'}
    if op_type in ['layer_norm_grad']:
        return in_name not in {'X', 'Y@GRAD'}
    return False


def _keep_fp32_output(op, out_name):
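    """Return True if the output slot `out_name` of `op` has to stay in FP32 under AMP (covers forward and backward ops)."""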
    op_type = op.type
    if op_type in ['batch_norm', 'fused_bn_add_activation']:
        return out_name != 'Y'
    if op_type == 'layer_norm' and _keep_layer_norm_scale_bias_to_fp32():
        return out_name != 'Y'
    if op_type == 'resnet_unit':
        return out_name not in {'Y', 'ConvX', 'ConvZ'}
    if op_type in ['fused_attention', 'fused_feedforward']:
        return out_name in {
            'LnMean',
            'LnVariance',
            'Ln2Mean',
            'Ln2Variance',
            'Ln1Mean',
            'Ln1Variance',
        }
    # backward
    if op_type in ['layer_norm_grad']:
        return out_name != 'X@GRAD'
    if op_type in ['batch_norm_grad']:
        return out_name != 'X@GRAD'
    return False


class FP16State:
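    """
    Hold the FP16 decisions for one program: mark which ops should run in FP16,
    rewrite the tensor dtypes accordingly, and insert the cast ops needed at the
    FP16/FP32 boundaries for both forward and backward ops.
    """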
    def __init__(
        self,
        program,
        amp_list,
        dist_context,
        use_fp16_guard,
        input_data_var_names=None,
    ):
        self.program = program
        self.amp_list = amp_list
        self.use_fp16_guard = use_fp16_guard
        self.dist_context = dist_context
        self.grad_op_to_op_map = (
            self.dist_context.dist_op_context.grad_op_id_to_op_id
        )
        if input_data_var_names:
            self.input_data_var_names = input_data_var_names
        else:
            self.input_data_var_names = []
        self._op_fp16_dict = (
            {}
        )  # op_id --> True/False. 'True' means that the op should run in fp16 mode.
        # a trick to determine the leaf tensor nodes in the program: {varname: generator_op_id}
        self.forward_non_leaf_tensors = {}
        # record the cast ops that are inserted for a forward op
        self.forward_input_cast_ops = defaultdict(
            list
        )  # {forward_op_id: [(output_name, input_name, out_dtype, in_dtype, slot_name), ]}
        self.is_train = False
        self.out_var_op_deps = {}

    def _is_fp16_op(self, op_id):
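        # True: run in fp16; False: keep fp32; None: op was not marked (e.g. skipped)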
        return self._op_fp16_dict.get(op_id, None)

    def _build_state(self):
        """
        mark the execution mode (fp16 or fp32) for ops in all blocks,
        including forward ops & backward ops
        """
        # mark op dtype
        # assume all backward blocks are behind forward blocks
        for block in self.program.blocks:
            for op in block.ops:
                for name in op.output_arg_names:
                    if name not in self.out_var_op_deps:
                        self.out_var_op_deps[name] = [op.desc.original_id()]
                    else:
                        self.out_var_op_deps[name].extend(
                            [op.desc.original_id()]
                        )

                self._mark_op(op)

        # set forward tensor dtype
        for block in self.program.blocks:
            self.resolute_tensor_dtype(block)

        # insert cast ops
        for block in self.program.blocks:
            self.cast_block(block)

        return self.is_train

    def _mark_op(self, op):

        if op.type in __amp_skip_ops__:
            return

        if is_forward_op(op):

            # ernie inference trick
            if op.type == "assign" and "array_" in op.input_arg_names[0]:
                self._op_fp16_dict[op.desc.original_id()] = False
                return
            # If the assign op is an in-place operation, its exec mode should be the same as that of the op which created output_var.
            if op.type == "assign":
                out_name = op.output_arg_names[0]
                if len(self.out_var_op_deps[out_name]) > 1:
                    if not self._op_fp16_dict[
                        self.out_var_op_deps[out_name][0]
                    ]:
                        self._op_fp16_dict[op.desc.original_id()] = False
                    else:
                        self._op_fp16_dict[op.desc.original_id()] = True
                    return

            if _need_keep_fp32(
                op, self.amp_list.unsupported_list, self.use_fp16_guard
            ):
                self._op_fp16_dict[op.desc.original_id()] = False
            else:
                self._op_fp16_dict[op.desc.original_id()] = True
            for var_name in op.output_arg_names:
                # assert var_name not in self.forward_non_leaf_tensors, "{}".format(var_name)
                self.forward_non_leaf_tensors[var_name] = op.desc.id()

        elif is_backward_op(op) == int(OpRole.Backward):

            if op.desc.original_id() in self.grad_op_to_op_map:
                fwd_op_id = self.grad_op_to_op_map[op.desc.original_id()]
                assert fwd_op_id in self._op_fp16_dict, "{}".format(str(op))
                self._op_fp16_dict[op.desc.original_id()] = self._op_fp16_dict[
                    fwd_op_id
                ]

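        # 257 == int(OpRole.Backward) | int(OpRole.Loss), i.e. the backward op of the loss,
        # which only appears when the program is built for training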
        if int(op.attr('op_role')) == 257:
            self.is_train = True

    def set_var_to_fp16(self, var_name, block):
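        """Set the dtype of `var_name` (looked up recursively in `block`) to FP16 if it is currently FP32."""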
        var = None
        try:
            var = block.var(var_name)
        except ValueError as e:
            var = block._var_recursive(var_name)
            # var = self.program.global_block().var(var_name)

        # NOTE(JZ-LIANG) "array_" is a hack adopted for ernie3.0 inference, since there is
        # a trick which casts the LOD_TENSOR_ARRAY to float32 in the while block to reset the LOD_TENSOR_ARRAY
        if var is None or var.type not in _valid_types or "array_" in var_name:
            return

        if var.dtype == core.VarDesc.VarType.FP32:
            var.desc.set_dtype(core.VarDesc.VarType.FP16)

    def resolute_tensor_dtype(self, block):
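        """Rewrite the dtypes of the tensors in `block` according to the fp16/fp32 marks of their ops."""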

        for op in block.ops:
            if is_forward_op(op):
                # NOTE (JZ-LIANG) unexpected cast op when the user calls "+, -, *, /" in python
                if (
                    self._is_fp16_op(op.desc.original_id()) is True
                    or op.type == "cast"
                ):
                    for in_name in op.input_names:
                        if _keep_fp32_input(op, in_name):
                            continue
                        for in_var_name in op.input(in_name):
                            if (
                                in_var_name not in self.forward_non_leaf_tensors
                                and in_var_name not in self.input_data_var_names
                            ):
                                self.set_var_to_fp16(in_var_name, block)
                    for out_name in op.output_names:
                        if _keep_fp32_output(op, out_name):
                            continue
                        for out_var_name in op.output(out_name):
                            self.set_var_to_fp16(out_var_name, block)
                    set_op_dtype_to_fp16(op)
                # NOTE (JZ-LIANG) unexpected cast op when the user calls "+, -, *, /" in python
                elif self._is_fp16_op(op.desc.original_id()) is False:
                    for out_var_name in op.output_arg_names:
                        out_var = block.vars.get(out_var_name)
                        if out_var is None or out_var.type not in _valid_types:
                            continue
                        if out_var.dtype == core.VarDesc.VarType.FP16:
                            out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
            elif is_backward_op(op):
                if self._is_fp16_op(op.desc.original_id()) is True:
                    for out_name in op.output_names:
                        if _keep_fp32_output(op, out_name):
                            continue
                        for out_var_name in op.output(out_name):
                            self.set_var_to_fp16(out_var_name, block)
                    set_op_dtype_to_fp16(op)
                # NOTE (JZ-LIANG) unexpected cast op when the user calls "+, -, *, /" in python
                elif self._is_fp16_op(op.desc.original_id()) is False:
                    for out_var_name in op.output_arg_names:
                        out_var = block.vars.get(out_var_name)
                        if out_var is None or out_var.type not in _valid_types:
                            continue
                        if out_var.dtype == core.VarDesc.VarType.FP16:
                            out_var.desc.set_dtype(core.VarDesc.VarType.FP32)

    def cast_block(self, block):
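        """Insert the cast ops required at the fp16/fp32 boundaries for every op in `block`."""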
        dist_op_context = self.dist_context.dist_op_context
        idx = 0
        while idx < len(block.ops):
            op = block.ops[idx]
            num_cast_ops = 0

            if op.type in __amp_skip_ops__:
                idx += 1
                continue
            elif is_forward_op(op):
                if self._is_fp16_op(op.desc.original_id()) is False:
                    num_cast_ops = self._insert_forward_cast_ops(
                        op,
                        idx,
                        block,
                        core.VarDesc.VarType.FP16,
                        core.VarDesc.VarType.FP32,
                        self.dist_context,
                    )
                elif self._is_fp16_op(op.desc.original_id()) is True:
                    num_cast_ops = self._insert_forward_cast_ops(
                        op,
                        idx,
                        block,
                        core.VarDesc.VarType.FP32,
                        core.VarDesc.VarType.FP16,
                        self.dist_context,
                    )
            elif is_backward_op(op):
                if op.desc.original_id() in dist_op_context.grad_op_id_to_op_id:
                    if self._is_fp16_op(op.desc.original_id()) is False:
                        num_cast_ops = self._insert_backward_cast_ops(
                            op,
                            idx,
                            block,
                            core.VarDesc.VarType.FP16,
                            core.VarDesc.VarType.FP32,
                            self.dist_context,
                        )
                    elif self._is_fp16_op(op.desc.original_id()) is True:
                        num_cast_ops = self._insert_backward_cast_ops(
                            op,
                            idx,
                            block,
                            core.VarDesc.VarType.FP32,
                            core.VarDesc.VarType.FP16,
                            self.dist_context,
                        )
                elif op.type == "sum":
                    # the dtypes of all inputs of sum should be equal and the output dtype should follow the input
                    out_var_name = op.output_arg_names[0]
                    in_var_name = op.input_arg_names[0]
                    out_var = block.var(out_var_name)
                    in_var = block._find_var_recursive(in_var_name)
                    for in_var_name in op.input_arg_names:
                        assert (
                            in_var.dtype == block.var(in_var_name).dtype
                        ), "{}, {}, {}".format(
                            in_var, block.var(in_var_name), str(op)
                        )
                    out_var.desc.set_dtype(in_var.dtype)

            idx += num_cast_ops + 1
        block._sync_with_cpp()

    def _insert_forward_cast_ops(
        self, op, idx, block, src_dtype, dst_dtype, dist_context
    ):
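        """Cast the inputs of a forward op from `src_dtype` to `dst_dtype`, reusing an existing
        `<var>.cast_<dtype>` var when possible; return the number of cast ops inserted."""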

        num_cast_ops = 0

        for in_name in op.input_names:
            if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
                op, in_name
            ):
                continue

            consume_op_attr = dist_context.get_op_dist_attr_for_program(op)
            assert consume_op_attr is not None
            for in_var_name in op.input(in_name):
                in_var = block._find_var_recursive(in_var_name)
                if (
                    in_var is None
                    or in_var.type not in _valid_types
                    or in_var.dtype == dst_dtype
                ):
                    continue

                if in_var.dtype == src_dtype:
                    cast_name = (
                        in_var.name + '.cast_' + _dtype_to_str(dst_dtype)
                    )
                    cast_var = block.vars.get(cast_name)
                    self.forward_input_cast_ops[op.desc.original_id()] += [
                        (cast_name, in_var.name, dst_dtype, src_dtype, in_name)
                    ]

                    in_var_dist_attr = consume_op_attr.get_input_dist_attr(
                        in_var.name
                    )
                    assert in_var_dist_attr is not None
                    # actually insert the cast op
                    if cast_var is None or cast_var.dtype != dst_dtype:
                        # NOTE we take the dist attr for the cast op and cast var from the op that
                        # consumes the cast var, instead of from the op which generates the var
                        # refine op's dist_attr
                        ref_mesh = in_var_dist_attr.process_mesh
                        ref_mapping = in_var_dist_attr.dims_mapping

                        cast_var = block.create_var(
                            name=cast_name,
                            dtype=dst_dtype,
                            persistable=False,
                            stop_gradient=in_var.stop_gradient,
                        )
                        set_var_dist_attr(
                            dist_context, cast_var, ref_mapping, ref_mesh
                        )

                        op_namescope = "/"
                        if op.has_attr('op_namescope'):
                            op_namescope = op.attr('op_namescope')
                        cast_op = block._insert_op_without_sync(
                            idx,
                            type="cast",
                            inputs={"X": in_var},
                            outputs={"Out": cast_var},
                            attrs={
                                "in_dtype": in_var.dtype,
                                "out_dtype": cast_var.dtype,
                                OP_ROLE_KEY: OpRole.Forward,
                            },
                        )
                        cast_op._set_attr(
                            'op_namescope', op_namescope
                        )  # for recompute
                        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                            cast_op, ref_mesh, ref_mapping, dist_context
                        )
                        num_cast_ops += 1

                    op._rename_input(in_var.name, cast_name)
                    consume_op_attr.set_input_dist_attr(
                        cast_name, in_var_dist_attr
                    )

        if op.has_attr('out_dtype') and op.attr('out_dtype') != -1:
            assert op.attr('out_dtype') == dst_dtype

        return num_cast_ops

    def _insert_backward_cast_ops(
        self, op, idx, block, src_dtype, dst_dtype, dist_context
    ):
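        """Replay the forward casts on the matching grad op: rename its inputs/outputs to the
        cast vars and insert cast ops that bring the gradients back to `src_dtype`; return the
        number of cast ops inserted."""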

        num_cast_ops = 0
        op_id = op.desc.id()
        original_id = op.desc.original_id()
        dist_op_context = dist_context.dist_op_context
        forward_op_id = dist_op_context.grad_op_id_to_op_id[original_id]

        grad_op_attr = dist_context.get_op_dist_attr_for_program(op)
        assert grad_op_attr is not None

        for out_var_name in op.output_arg_names:
            out_var = block.var(out_var_name)
            if _keep_fp32_output(op, out_var.name):
                continue
            assert out_var.dtype == dst_dtype, "{}, {}".format(
                str(out_var), dst_dtype
            )

        for (
            cast_name,
            src_name,
            dst_dtype,
            src_dtype,
            slot_name,
        ) in self.forward_input_cast_ops[forward_op_id]:

            # some forward outputs are not needed by the backward computation, e.g. logit in softmax_with_cross_entropy
            if slot_name not in op.input_names:
                continue

            # rename input
            assert src_name in op.input(
                slot_name
            ), "var: {} not in op's {}. {}".format(src_name, slot_name, str(op))
            src_var_dist_attr = grad_op_attr.get_input_dist_attr(src_name)
            assert src_var_dist_attr is not None
            op._rename_input(src_name, cast_name)
            grad_op_attr.set_input_dist_attr(cast_name, src_var_dist_attr)

            # create cast grad
            grad_slot_name = slot_name + "@GRAD"
            if grad_slot_name not in op.output_names:
                continue

            # some forward inputs may be stop_gradient=True, e.g. input_mask
            if len(op.output(grad_slot_name)) == 0:
                continue
            assert (
                len(op.output(grad_slot_name)) == 1
            ), "[{}], Current Op: {}".format(grad_slot_name, str(op))
            grad_name = op.output(grad_slot_name)[0]
            grad = block.var(grad_name)
            grad_dist_attr = grad_op_attr.get_output_dist_attr(grad_name)
            assert grad_dist_attr is not None, "{}".format(grad_name)
            ref_mesh = grad_dist_attr.process_mesh
            ref_mapping = grad_dist_attr.dims_mapping

            cast_grad = block.create_var(
                name=unique_name.generate_with_ignorable_key(
                    "".join([cast_name, '@GRAD'])
                ),
                dtype=dst_dtype,
                shape=grad.shape,
                type=grad.type,
                persistable=grad.persistable,
                stop_gradient=grad.stop_gradient,
            )
            dist_context.set_tensor_dist_attr_for_program(
                cast_grad, grad_dist_attr
            )
            op._rename_output(grad_name, cast_grad.name)
            grad_op_attr.set_output_dist_attr(cast_grad.name, grad_dist_attr)

            # add cast
            cast_op = block._insert_op_without_sync(
                idx + 1,
                type="cast",
                inputs={"X": [cast_grad.name]},
                outputs={"Out": [grad.name]},
                attrs={
                    "in_dtype": dst_dtype,
                    "out_dtype": src_dtype,
                    OP_ROLE_KEY: OpRole.Backward,
                },
            )
            grad.desc.set_dtype(src_dtype)

            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                cast_op, ref_mesh, ref_mapping, dist_context
            )
            num_cast_ops += 1

        return num_cast_ops


def _check_and_update_gradient(grads, loss_scaling, name, dist_context):
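    """Append a check_finite_and_unscale op that unscales `grads` by `loss_scaling`;
    return the grads and the `found_inf` bool var."""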

    main_block = paddle.static.default_main_program().global_block()
    main_block._sync_with_cpp()

    check_type(grads, 'x', (tuple, list), 'check_finite_and_unscale')
    for e in grads:
        check_variable_and_dtype(
            e,
            "x",
            ['float16', 'float32', 'float64'],
            'check_finite_and_unscale',
        )

    found_inf = main_block.create_var(
        name=unique_name.generate_with_ignorable_key(
            ".".join(['find_infinite_scale', name])
        ),
        shape=[1],
        dtype='bool',
        type=core.VarDesc.VarType.LOD_TENSOR,
        persistable=False,
        stop_gradient=False,
    )
    set_var_dist_attr(dist_context, found_inf, [-1], world_process_group.ranks)

    inputs = {'X': grads, 'Scale': loss_scaling}
    outputs = {'Out': grads, 'FoundInfinite': found_inf}
    attrs = {'op_role': OpRole.Optimize}
    new_op = main_block.append_op(
        type='check_finite_and_unscale',
        inputs=inputs,
        outputs=outputs,
        attrs=attrs,
    )

    # Constructing dist attr from op_desc can
    # give all inputs and outputs default dist attrs
    new_op_dist_attr = OperatorDistAttr(new_op.desc)
    new_op_dist_attr.process_mesh = ProcessMesh(world_process_group.ranks)
    new_op_dist_attr.impl_idx = 0
    if len(world_process_group.ranks) > 1:
        new_op_dist_attr.impl_type = "check_finite_and_unscale"
    for g in grads:
        g_dist_attr = dist_context.get_tensor_dist_attr_for_program(g)
        assert g_dist_attr is not None
        new_op_dist_attr.set_input_dims_mapping(
            g.name, g_dist_attr.dims_mapping
        )
        new_op_dist_attr.set_output_dims_mapping(
            g.name, g_dist_attr.dims_mapping
        )
    dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
    return grads, found_inf


def _split_grads(params_grads):
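    """Split the gradients in `params_grads` into fp32 and fp16 groups."""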
    grads = [g for _, g in params_grads]
    fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32]
    fp16_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP16]
    assert len(fp32_grads) + len(fp16_grads) == len(
        grads
    ), "Data types of all grads must be either fp16 or fp32."
    return grads, fp32_grads, fp16_grads


def _set_op_dist_attr_with_ranks(new_op, ranks, block, dist_context):
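    """Give `new_op` a dist attr whose process mesh is `ranks`, copying the dims_mapping of each of its input/output tensors."""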
    new_op_dist_attr = OperatorDistAttr()
    new_op_dist_attr.process_mesh = ProcessMesh(ranks)
    new_op_dist_attr.impl_idx = 0
    for var_name in new_op.input_arg_names:
        var = block.var(var_name)
        var_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
        assert var_dist_attr is not None
        new_op_dist_attr.set_input_dims_mapping(
            var_name, var_dist_attr.dims_mapping
        )
    for var_name in new_op.output_arg_names:
        var = block.var(var_name)
        var_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
        assert var_dist_attr is not None
        new_op_dist_attr.set_output_dims_mapping(
            var_name, var_dist_attr.dims_mapping
        )
    dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)


def _get_memcopy_idx(block, found_inf_var):
    # use the reduce_any op that produces found_inf (the nan/inf check) as the anchor for now
    for idx, op in enumerate(block.ops):
        if (
            op.type == 'reduce_any'
            and op.output_arg_names[0] == found_inf_var.name
        ):
            return idx + 1

    raise RuntimeError(
        "could not find the correct location to insert memcopy for found_inf_var."
    )


def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"):
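    """Insert a memcpy_d2h op at position `idx` of `block` that copies `src_var` into a new host-side var; return that var."""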
    src_name = src_var.name
    output_var = block.create_var(
        name=unique_name.generate_with_ignorable_key(
            "".join(['memcopy_', src_name])
        ),
        dtype=src_var.dtype,
        shape=src_var.shape,
        type=core.VarDesc.VarType.LOD_TENSOR,
        persistable=False,
        stop_gradient=src_var.stop_gradient,
    )

    set_var_dist_attr(dist_context, output_var, [-1], world_process_group.ranks)

    # TODO: support CUDAPinned/NPU/XPU places
    if direction == "D2H":
        dst_place_type = 0
    else:
        raise NotImplementedError(
            "direction [{}] is not supported yet.".format(direction)
        )

    attrs = {'dst_place_type': dst_place_type}
    new_op = block._insert_op_without_sync(
        index=idx,
        type='memcpy_d2h',
        inputs={'X': [src_var]},
        outputs={'Out': [output_var]},
        attrs=attrs,
    )
    _set_op_dist_attr_with_ranks(
        new_op, world_process_group.ranks, block, dist_context
    )
    block._sync_with_cpp()
    return output_var


def cast_startup_program():
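    """Align the startup program with the casted main program: make the initialization ops of FP16 parameters produce FP16 values."""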
    main_program = default_main_program()
    startup_program = default_startup_program()

    param_to_dtype = {}
    for block in main_program.blocks:
        for p in block.all_parameters():
            param_to_dtype[p.name] = p.dtype

    def is_initialization_op(op):
        comm_op_prefix = "c_"
        op_type = op.type
        if op_type.startswith(comm_op_prefix):
            return False

        if len(op.output_arg_names) != 1 and len(op.input_arg_names) != 0:
            return False

        return True

    for op in startup_program.global_block().ops:
        if is_initialization_op(op):
            output_name = op.output_arg_names[0]
            if (
                param_to_dtype.get(output_name, None)
                == core.VarDesc.VarType.FP16
            ):
                assert op.has_attr(
                    'dtype'
                ), "initialization op is supposed to have a dtype attribute, but got {}.".format(
                    str(op)
                )
                if op.attr('dtype') == core.VarDesc.VarType.FP32:
                    op._set_attr('dtype', core.VarDesc.VarType.FP16)


@register_pass("auto_parallel_fp16")
class FP16Pass(AMPPass):
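    """
    Auto-parallel FP16 pass: mark and cast the main program with FP16State, cast the
    startup program accordingly, and, when training, set up loss scaling, gradient
    inf/nan checking and the optimizer's multi-precision mode.
    """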
    def __init__(self):
        super().__init__()

    # NOTE: why can FP16Pass override apply_single_impl instead of
    # apply_impl? AMP is an optimization pass for the serial program;
    # in the distributed scenario, all ranks should apply the same modification.
    def _apply_single_impl(self, main_program, startup_program, context):
        self.dist_context = self.get_attr("dist_context")
        params_grads = self.get_attr("params_grads")

        amp_list = AutoMixedPrecisionLists(
            set(self.get_attr("custom_white_list")),
            set(self.get_attr("custom_black_list")),
            None,
        )

        # NOTE do not change the input data dtype, since it is controlled by the dataloader
        # and is out of the control of the FP16 pass
        input_data_var_names = [var.name for var in self.get_attr("input_data")]

        with paddle.static.program_guard(main_program, startup_program):
            fp16_state = FP16State(
                main_program,
                amp_list,
                self.dist_context,
                self.get_attr("use_fp16_guard"),
                input_data_var_names,
            )
            is_train = fp16_state._build_state()

            cast_startup_program()

        if is_train:
            with paddle.static.program_guard(main_program, startup_program):
                # TODO (JZ-LIANG) support casting only the forward program when doing inference
                self._init_amp_var()
                self._scale_loss()

                grads, fp32_grads, fp16_grads = _split_grads(params_grads)

                if (
                    self.get_attr("use_dynamic_loss_scaling")
                    or self.get_attr("init_loss_scaling") != 1.0
                ):
                    found_infs = []
                    if fp32_grads:
                        with main_program._optimized_guard([]):
                            _, found_inf_fp32 = _check_and_update_gradient(
                                fp32_grads,
                                self._loss_scaling,
                                "@fp32",
                                self.dist_context,
                            )
                        found_infs.append(found_inf_fp32)
                    if fp16_grads:
                        with main_program._optimized_guard([]):
                            _, found_inf_fp16 = _check_and_update_gradient(
                                fp16_grads,
                                self._loss_scaling,
                                "@fp16",
                                self.dist_context,
                            )
                        found_infs.append(found_inf_fp16)
                    with main_program._optimized_guard([]):
                        block = main_program.global_block()

                        # all_infs = paddle.fluid.layers.concat(found_infs)
                        all_infs = block.create_var(
                            name=paddle.utils.unique_name.generate_with_ignorable_key(
                                ".".join(['concat', 'tmp'])
                            ),
                            dtype=found_infs[0].dtype,
                            shape=None,
                            lod_level=found_infs[0].lod_level,
                            type=found_infs[0].type,
                            persistable=False,
                            stop_gradient=False,
                        )
                        concat_op = block.append_op(
                            type='concat',
                            inputs={'X': found_infs},
                            outputs={'Out': [all_infs]},
                            attrs={'axis': 0},
                        )
                        set_var_dist_attr(
                            self.dist_context,
                            all_infs,
                            [-1],
                            world_process_group.ranks,
                        )
                        _set_op_dist_attr_with_ranks(
                            concat_op,
                            world_process_group.ranks,
                            block,
                            self.dist_context,
                        )

                        # found_inf = paddle.fluid.layers.reduce_any(all_infs)
                        found_inf = block.create_var(
                            name=paddle.utils.unique_name.generate_with_ignorable_key(
                                ".".join(['reduce_any', 'tmp'])
                            ),
                            dtype=all_infs.dtype,
                            shape=None,
                            lod_level=all_infs.lod_level,
                            type=all_infs.type,
                            persistable=False,
                            stop_gradient=False,
                        )
                        reduce_any_op = block.append_op(
                            type='reduce_any',
                            inputs={'X': all_infs},
                            outputs={'Out': found_inf},
                            attrs={
                                'dim': [0],
                                'keep_dim': False,
                                'reduce_all': True,
                            },
                        )
                        set_var_dist_attr(
                            self.dist_context,
                            found_inf,
                            [-1],
                            world_process_group.ranks,
                        )
                        _set_op_dist_attr_with_ranks(
                            reduce_any_op,
                            world_process_group.ranks,
                            block,
                            self.dist_context,
                        )

                if self.get_attr("use_dynamic_loss_scaling"):
                    with main_program._optimized_guard([]):
                        if fp32_grads:
                            self._update_loss_scaling(fp32_grads, found_inf)
                        if fp16_grads:
                            self._update_loss_scaling(fp16_grads, found_inf)

            # modify optimizer
            base_opt = self.get_attr("base_opt")
            base_opt._multi_precision = True
            if self.get_attr("use_optimizer_fp16"):
                base_opt._multi_precision = False
            if isinstance(
                base_opt,
                (paddle.static.Adam, paddle.optimizer.AdamW),
            ):
                with main_program._optimized_guard([]):
                    # found_inf = paddle.tensor.creation._memcpy(
                    #     found_inf, paddle.CPUPlace())
                    insert_idx = _get_memcopy_idx(block, found_inf)
                    found_inf = _insert_memcopy(
                        block, insert_idx, found_inf, self.dist_context
                    )
                base_opt._set_auxiliary_var('found_inf', found_inf.name)
            elif hasattr(base_opt, "_set_auxiliary_var"):
                base_opt._set_auxiliary_var('found_inf', found_inf.name)