# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.distributed.auto_parallel.static.dist_attribute import (
    OperatorDistAttr,
)
from paddle.distributed.auto_parallel.static.process_group import (
    get_world_process_group,
)
from paddle.distributed.auto_parallel.static.utils import (
    naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
    set_var_dist_attr,
)
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from paddle.fluid.data_feeder import check_type, check_variable_and_dtype
from paddle.framework import core
from paddle.static.amp.bf16.amp_utils import (
    AutoMixedPrecisionListsBF16,
    _is_in_fp32_varnames,
)
from paddle.static.amp.fp16_utils import (
    AutoMixedPrecisionLists,
    _is_in_black_varnames,
    _keep_fp32_input,
    _keep_fp32_output,
    _rename_arg,
    _valid_types,
    find_op_index,
    find_true_post_op,
    find_true_prev_op,
)
from paddle.utils import unique_name

from ..auto_parallel.process_mesh import ProcessMesh
from ..auto_parallel.static.utils import (
    is_backward_op,
    is_forward_op,
    is_loss_grad_op,
    is_loss_op,
    is_optimize_op,
)
from .pass_base import PassBase, register_pass

world_process_group = get_world_process_group()

__amp_skip_ops__ = [
    'create_py_reader',
    'create_double_buffer_reader',
    'cast',
    'while',
]


def _dtype_to_str(dtype):
    if dtype == core.VarDesc.VarType.FP16:
        return 'fp16'
    elif dtype == core.VarDesc.VarType.BF16:
        return 'bf16'
    else:
        return 'fp32'


def _str_to_dtype(dstr):
    if dstr == 'float16':
        return core.VarDesc.VarType.FP16
    elif dstr == 'bfloat16':
        return core.VarDesc.VarType.BF16
    else:
        return core.VarDesc.VarType.FP32
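# For reference: the two helpers above map between Paddle dtype enums and the
# short dtype names used in cast-var suffixes, e.g.
# _dtype_to_str(core.VarDesc.VarType.BF16) returns 'bf16' and
# _str_to_dtype('bfloat16') returns core.VarDesc.VarType.BF16.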


class AMPLists:
    def __init__(
        self,
        white_list=None,
        black_list=None,
        black_varnames=None,
        dtype="float16",
    ):
        self._amp_list = None
        if dtype == "float16":
            self._amp_list = AutoMixedPrecisionLists(
                set(white_list), set(black_list), set(black_varnames)
            )
        elif dtype == "bfloat16":
            self._amp_list = AutoMixedPrecisionListsBF16(
                set(white_list), set(black_list), set(black_varnames)
            )

        assert self._amp_list is not None
        self._dtype = dtype
        self._is_float16 = dtype == "float16"

    @property
    def white_list(self):
        if self._is_float16:
            return self._amp_list.white_list
        else:
            return self._amp_list.bf16_list

    @property
    def black_list(self):
        if self._is_float16:
            return self._amp_list.black_list
        else:
            return self._amp_list.fp32_list

    @property
    def gray_list(self):
        return self._amp_list.gray_list

    @property
    def black_varnames(self):
        if self._is_float16:
            return self._amp_list.black_varnames
        else:
            return self._amp_list.fp32_varnames

    @property
    def is_fp16(self):
        return self._is_float16

    @property
    def dtype(self):
        return self._dtype

    @property
    def amp_list(self):
        return self._amp_list

    def _is_in_black_fp32_varnames(self, op):
        if self._is_float16:
            return _is_in_black_varnames(op, self._amp_list)
        else:
            return _is_in_fp32_varnames(op, self._amp_list)

    def _op_keep_fp32_input(self, op, in_name):
        if self._is_float16:
            return _keep_fp32_input(op, in_name)
        else:
            if op.type in ['batch_norm', 'layer_norm']:
                return in_name != 'X'
            if op.type == 'fused_bn_add_activation':
                return in_name not in {'X', 'Z'}
            return False

    def _op_keep_fp32_output(self, op, out_name):
        if self._is_float16:
            return _keep_fp32_output(op, out_name)
        else:
            if op.type in [
                'batch_norm',
                'fused_bn_add_activation',
                'layer_norm',
            ]:
                return out_name != 'Y'
            return False
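# A minimal usage sketch of the AMPLists helper above (illustrative only; the
# custom op names are placeholders, not a recommended configuration). AMPPass
# builds this object from its "custom_*" attributes in _apply_single_impl:
#
#     amp_lists = AMPLists(
#         white_list=set(),
#         black_list={"softmax"},
#         black_varnames=set(),
#         dtype="float16",
#     )
#     assert amp_lists.is_fp16
#     assert "softmax" in amp_lists.black_list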


class AMPState:
    def __init__(self, program, amp_lists, amp_dtype, dist_context):
        self.program = program
        self.dist_context = dist_context
        self.amp_lists = amp_lists
        self.amp_dtype = amp_dtype
        self.grad_op_to_op_map = (
            dist_context.dist_op_context.grad_op_id_to_op_id
        )

        # op_id --> True/False. 'True' means that the current op is in fp16/bf16 mode.
        self._op_fp16_dict = {}
        # fwd_op_id --> {old_name: cast_name}
        self._var_name_dict = {}
        # out_var_name --> [op_ids]
        self.out_var_op_deps = {}

    def _is_fp16_op(self, op_id):
        return self._op_fp16_dict.get(op_id, None)

    def build_state(self):
        is_train = False
        for block in self.program.blocks:
            for op in block.ops:
                # record inplace operations and their outputs
                for name in op.output_arg_names:
                    if name not in self.out_var_op_deps:
                        self.out_var_op_deps[name] = [op.desc.original_id()]
                    else:
                        self.out_var_op_deps[name].extend(
                            [op.desc.original_id()]
                        )

                if is_loss_grad_op(op):
                    is_train = True

                if op.type in __amp_skip_ops__:
                    continue

                if is_forward_op(op):
                    self._mark_black_white_ops(op, block.ops, block)
                elif is_backward_op(op):
                    if op.desc.original_id() in self.grad_op_to_op_map:
                        fwd_op_id = self.grad_op_to_op_map[
                            op.desc.original_id()
                        ]
                        assert fwd_op_id in self._op_fp16_dict, "{}".format(
                            str(op)
                        )
                        self._op_fp16_dict[
                            op.desc.original_id()
                        ] = self._is_fp16_op(fwd_op_id)
                elif is_optimize_op(op):
                    break

        # insert cast ops
        for block in self.program.blocks:
            self._cast_block(block)

        return is_train

    def _mark_black_white_ops(self, op, ops, block):

        # ernie inference trick
        if op.type == "assign" and "array_" in op.input_arg_names[0]:
            self._op_fp16_dict[op.desc.original_id()] = False
            return

        # If assign op is inplace-operation, assign op exec mode should be same with the created op of output_var.
        if op.type == "assign":
            out_name = op.output_arg_names[0]
            if len(self.out_var_op_deps[out_name]) > 1:
                if not self._is_fp16_op(self.out_var_op_deps[out_name][0]):
                    self._op_fp16_dict[op.desc.original_id()] = False
                else:
                    self._op_fp16_dict[op.desc.original_id()] = True
                return

        if (
            self.amp_lists.black_varnames is not None
            and self.amp_lists._is_in_black_fp32_varnames(op)
        ):
            self._op_fp16_dict[op.desc.original_id()] = False
            return
        if op.type in self.amp_lists.black_list:
            self._op_fp16_dict[op.desc.original_id()] = False
        elif op.type in self.amp_lists.white_list:
            self._op_fp16_dict[op.desc.original_id()] = True
        elif op.type in self.amp_lists.gray_list:
            is_black_op = False
            is_white_op = False
            for in_name in op.input_names:
                # if this op has inputs
                if in_name:
                    for in_var_name in op.input(in_name):
                        in_var = block._var_recursive(in_var_name)
                        # this in_var isn't the output of other op
                        if in_var.op is None:
                            continue
                        elif in_var.op is op:
                            prev_op = find_true_prev_op(ops, op, in_var_name)
                            if prev_op is None:
                                continue
                        else:
                            prev_op = in_var.op
                        # if it's one of inputs
                        if (
                            self._is_fp16_op(prev_op.desc.original_id())
                            is False
                            or prev_op.type in self.amp_lists.black_list
                        ):
                            is_black_op = True
                        elif (
                            self._is_fp16_op(prev_op.desc.original_id()) is True
                            or prev_op.type in self.amp_lists.white_list
                        ):
                            is_white_op = True
            if is_black_op:
                self._op_fp16_dict[op.desc.original_id()] = False
            elif is_white_op:
                self._op_fp16_dict[op.desc.original_id()] = True
            else:
                pass
        else:
            # For numerical safe, we apply fp32 computation on ops that
            # are not determined which list they should stay.
            self._op_fp16_dict[op.desc.original_id()] = False
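    # Resolution order used by _mark_black_white_ops above: a black_varnames
    # match or a black_list op runs in fp32; a white_list op runs in fp16/bf16;
    # a gray_list op follows the producers of its inputs; any other op falls
    # back to fp32 for numerical safety.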

    def _cast_block(self, block):
        idx = 0
        appended_grad_times = 0
        while idx < len(block.ops):
            op = block.ops[idx]
            num_cast_ops = 0

            if op.type in __amp_skip_ops__:
                idx += 1
                continue

            elif is_forward_op(op):
                if self._is_fp16_op(op.desc.original_id()) is False:
                    num_cast_ops = self._insert_cast_op_forward(
                        block,
                        op,
                        idx,
                        _str_to_dtype(self.amp_dtype),
                        core.VarDesc.VarType.FP32,
                        self.dist_context,
                    )
                elif self._is_fp16_op(op.desc.original_id()) is True:
                    if self.amp_dtype == "bfloat16":
                        if op.has_attr('use_mkldnn'):
                            op._set_attr('use_mkldnn', True)
                            op._set_attr('mkldnn_data_type', 'bfloat16')
                        elif (
                            op.has_attr('dtype')
                            and op.attr('dtype') == core.VarDesc.VarType.FP32
                        ):
                            op._set_attr('dtype', core.VarDesc.VarType.BF16)
                    num_cast_ops = self._insert_cast_op_forward(
                        block,
                        op,
                        idx,
                        core.VarDesc.VarType.FP32,
                        _str_to_dtype(self.amp_dtype),
                        self.dist_context,
                    )
            elif is_backward_op(op):
                # NOTE: the map in `grad_var_to_var` may be changed when the var is casted,
                # which will affect the dist_op to insert allreduce_sum op.
                op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
                    op
                )
                if is_backward_op(op) and (
                    is_forward_op(block.ops[idx - 1])
                    or is_loss_op(block.ops[idx - 1])
                ):
                    if not op_dist_attr.is_recompute:
                        appended_grad_times += 1

                if op.desc.original_id() in self.grad_op_to_op_map:
                    if self._is_fp16_op(op.desc.original_id()) is False:  # fp32
                        num_cast_ops = self._insert_cast_op_backward(
                            block,
                            op,
                            idx,
                            _str_to_dtype(self.amp_dtype),
                            core.VarDesc.VarType.FP32,
                            self.dist_context,
                            appended_grad_times,
                        )
                    elif (
                        self._is_fp16_op(op.desc.original_id()) is True
                    ):  # fp16/bf16
                        if self.amp_dtype == "bfloat16":
                            if op.has_attr('use_mkldnn'):
                                op._set_attr('use_mkldnn', True)
                                op._set_attr('mkldnn_data_type', 'bfloat16')
                            elif (
                                op.has_attr('dtype')
                                and op.attr('dtype')
                                == core.VarDesc.VarType.FP32
                            ):
                                op._set_attr('dtype', core.VarDesc.VarType.BF16)
                        num_cast_ops = self._insert_cast_op_backward(
                            block,
                            op,
                            idx,
                            core.VarDesc.VarType.FP32,
                            _str_to_dtype(self.amp_dtype),
                            self.dist_context,
                            appended_grad_times,
                        )
                elif op.type == "sum":
                    # all input dtypes of sum should be equal, and the output dtype should follow the input
                    out_var_name = op.desc.output_arg_names()[0]
                    in_var_name = op.desc.input_arg_names()[0]
                    out_var = block.var(out_var_name)
                    in_var = block._find_var_recursive(in_var_name)
                    for in_var_name in op.input_arg_names:
                        assert (
                            in_var.dtype == block.var(in_var_name).dtype
                        ), "{}, {}, {}".format(
                            in_var, block.var(in_var_name), str(op)
                        )
                    out_var.desc.set_dtype(in_var.dtype)
                elif int(op.attr('op_role')) == 257:
                    pass
                else:
                    raise ValueError(
                        "'{}' op is not supported in the complete amp pass.".format(
                            op.type
                        )
                    )
            idx += num_cast_ops + 1
        block._sync_with_cpp()

    def _insert_cast_op_forward(
        self, block, op, idx, src_dtype, dst_dtype, dist_context
    ):
        """
        Insert cast ops for the forward pass only; modified from paddle.static.amp.
        """
        num_cast_ops = 0
        var_name_dict = {}
        for in_name in op.input_names:
            if (
                src_dtype == core.VarDesc.VarType.FP32
                and self.amp_lists._op_keep_fp32_input(op, in_name)
            ):
                continue
            for in_var_name in op.input(in_name):
                in_var = block._find_var_recursive(in_var_name)
                if in_var.type not in _valid_types or in_var.dtype == dst_dtype:
                    continue
                if in_var.dtype == src_dtype:
                    cast_name = (
                        in_var.name + '.cast_' + _dtype_to_str(dst_dtype)
                    )
                    cast_var = block.vars.get(cast_name)
                    var_name_dict[in_var.name] = cast_name
                    consume_op_attr = dist_context.get_op_dist_attr_for_program(
                        op
                    )
                    assert consume_op_attr is not None
                    if cast_var is None or cast_var.dtype != dst_dtype:
                        # NOTE: the cast op and cast var take their dist attr
                        # from the op that consumes the cast var, not from the
                        # op that generates the original var.
                        in_var_dist_attr = consume_op_attr.get_input_dist_attr(
                            in_var.name
                        )
                        assert in_var_dist_attr is not None
                        ref_mesh = in_var_dist_attr.process_mesh
                        ref_mapping = in_var_dist_attr.dims_mapping
                        consume_op_attr.set_input_dist_attr(
                            cast_name, in_var_dist_attr
                        )

                        cast_var = block.create_var(
                            name=cast_name,
                            dtype=dst_dtype,
                            persistable=False,
                            stop_gradient=in_var.stop_gradient,
                        )
                        set_var_dist_attr(
                            dist_context, cast_var, ref_mapping, ref_mesh
                        )

                        op_namescope = "/"
                        if op.has_attr('op_namescope'):
                            op_namescope = op.attr('op_namescope')
                        cast_op = block._insert_op_without_sync(
                            idx,
                            type="cast",
                            inputs={"X": in_var},
                            outputs={"Out": cast_var},
                            attrs={
                                "in_dtype": in_var.dtype,
                                "out_dtype": cast_var.dtype,
                            },
                        )
                        cast_op._set_attr(
                            'op_namescope', op_namescope
                        )  # for recompute
                        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                            cast_op, ref_mesh, ref_mapping, dist_context
                        )
                        num_cast_ops += 1
                    else:
                        in_var_dist_attr = consume_op_attr.get_input_dist_attr(
                            in_var.name
                        )
                        consume_op_attr.set_input_dist_attr(
                            cast_name, in_var_dist_attr
                        )
                    _rename_arg(op, in_var.name, cast_name)
                else:
                    if op.has_attr('in_dtype'):
                        op._set_attr('in_dtype', dst_dtype)
        self._var_name_dict[op.desc.original_id()] = var_name_dict

        if (
            src_dtype == core.VarDesc.VarType.FP32
            and dst_dtype == _str_to_dtype(self.amp_dtype)
        ):
            for out_name in op.output_names:
                if self.amp_lists._op_keep_fp32_output(op, out_name):
                    continue
                for out_var_name in op.output(out_name):
                    out_var = block._var_recursive(out_var_name)
                    if out_var.type not in _valid_types:
                        continue
                    if out_var.dtype == core.VarDesc.VarType.FP32:
                        out_var.desc.set_dtype(_str_to_dtype(self.amp_dtype))
                        if op.has_attr('out_dtype'):
                            op._set_attr(
                                'out_dtype', _str_to_dtype(self.amp_dtype)
                            )
        return num_cast_ops
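    # Naming sketch (orientation only; "tmp_0" is a made-up variable name): when
    # _insert_cast_op_forward casts an fp32 input of an fp16 op, it creates a var
    # named "tmp_0.cast_fp16" and records {"tmp_0": "tmp_0.cast_fp16"} in
    # self._var_name_dict[fwd_op_id]; _insert_cast_op_backward below reuses that
    # map and names the matching gradient cast var "tmp_0.cast_fp16@GRAD".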

    def _insert_cast_op_backward(
        self,
        block,
        op,
        idx,
        src_dtype,
        dst_dtype,
        dist_context,
        appended_grad_times,
    ):
        """only for backward cast"""

        def _keep_fp32_input(op, in_name):
            op_type = op.type
            if op_type in ['layer_norm_grad']:
                return in_name not in {'X', 'Y@GRAD'}
            return False

        def _keep_fp32_output(op, out_name):
            op_type = op.type
            if op_type in ['layer_norm_grad']:
                return out_name != 'X@GRAD'
            return False

        num_cast_ops = 0
        original_id = op.desc.original_id()
        dist_op_context = dist_context.dist_op_context
        fwd_op_id = self.grad_op_to_op_map[original_id]

        for in_name in op.input_names:
            if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
                op, in_name
            ):
                for in_var_name in op.input(in_name):
                    in_var = block._var_recursive(in_var_name)
                    assert in_var.dtype == core.VarDesc.VarType.FP32
                continue

            for in_var_name in op.input(in_name):
                in_var = block._var_recursive(in_var_name)
                if in_var.dtype == src_dtype:
                    consume_op_attr = dist_context.get_op_dist_attr_for_program(
                        op
                    )
                    if in_var_name in self._var_name_dict[fwd_op_id]:
                        # NOTE: if the in_var consumed by this grad op was cast
                        # in the forward pass, rename it and reset its dist_attr.
                        cast_name = self._var_name_dict[fwd_op_id][in_var_name]
                        op.desc._rename_input(in_var_name, cast_name)
                        in_var_dist_attr = consume_op_attr.get_input_dist_attr(
                            in_var_name
                        )
                        consume_op_attr.set_input_dist_attr(
                            cast_name, in_var_dist_attr
                        )
                    else:
                        assert (
                            in_var.dtype == dst_dtype
                        ), "op [{}] expect input [{}] to be dtype [{}] BUT got [{}]. {}".format(
                            op.type,
                            in_name,
                            dst_dtype,
                            in_var.dtype,
                            str(op),
                        )

        for out_name in op.output_names:
            if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_output(
                op, out_name
            ):
                for out_var_name in op.output(out_name):
                    out_var = block._var_recursive(out_var_name)
                    assert out_var.dtype == core.VarDesc.VarType.FP32
                continue

            for out_var_name in op.output(out_name):
                out_var = block._var_recursive(out_var_name)
                out_var_name_prefix = out_var_name[: out_var_name.find("@")]
                fwd_var = block._var_recursive(out_var_name_prefix)
                # NOTE: the dtype of the grad op's out_var should equal the
                # dtype of the corresponding fwd_var
                if out_var.dtype != fwd_var.dtype:
                    out_var.desc.set_dtype(fwd_var.dtype)

                if out_var.dtype == src_dtype:
                    if out_var_name_prefix in self._var_name_dict[fwd_op_id]:
                        # NOTE: if the out_var of this grad op corresponds to a
                        # var that was cast in the forward pass, rename it,
                        # reset its dist_attr, and insert a cast op to convert
                        # the cast_var back to the original dtype
                        consume_op_attr = (
                            dist_context.get_op_dist_attr_for_program(op)
                        )
                        fwd_cast_name = self._var_name_dict[fwd_op_id][
                            out_var_name_prefix
                        ]
                        suffix = ""
                        if "@RENAME" in out_var_name:
                            suffix = out_var_name[
                                out_var_name.find("@RENAME") :
                            ]
                        cast_name = fwd_cast_name + "@GRAD" + suffix
                        cast_var = block.vars.get(cast_name)
                        if cast_var is None or cast_var.dtype != dst_dtype:
                            op.desc._rename_output(out_var_name, cast_name)
                            out_var_dist_attr = (
                                consume_op_attr.get_output_dist_attr(
                                    out_var_name
                                )
                            )
                            ref_mesh = out_var_dist_attr.process_mesh
                            ref_mapping = out_var_dist_attr.dims_mapping
                            consume_op_attr.set_output_dist_attr(
                                cast_name, out_var_dist_attr
                            )
                            assert ref_mapping is not None
                            cast_var = block.create_var(
                                name=cast_name,
                                shape=out_var.shape,
                                dtype=dst_dtype,
                                persistable=False,
                                stop_gradient=out_var.stop_gradient,
                            )
                            set_var_dist_attr(
                                dist_context, cast_var, ref_mapping, ref_mesh
                            )
                            dist_op_context.grad_var_to_var[
                                appended_grad_times
                            ][cast_name] = fwd_cast_name

                            cast_op = block._insert_op(
                                idx + 1,
                                type="cast",
                                inputs={"X": cast_var},
                                outputs={"Out": out_var},
                                attrs={
                                    "in_dtype": cast_var.dtype,
                                    "out_dtype": out_var.dtype,
                                    "op_role": OpRole.Backward,
                                },
                            )
                            cast_op._remove_attr("op_role_var")
                            cast_op._remove_attr("op_namescope")
                            cast_op._remove_attr("with_quant_attr")
                            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                                cast_op, ref_mesh, ref_mapping, dist_context
                            )
                            num_cast_ops += 1
                else:
                    assert out_var.dtype == dst_dtype

        return num_cast_ops
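# How the pieces above fit together (a descriptive sketch of the flow driven by
# AMPPass._apply_single_impl below; the surrounding objects are assumed to come
# from the auto-parallel engine):
#
#     amp_lists = AMPLists(white, black, black_varnames, dtype)
#     amp_state = AMPState(main_program, amp_lists, dtype, dist_context)
#     is_train = amp_state.build_state()  # mark ops fp16/fp32 and insert casts
#     # for training programs the pass then moves backward casts, casts the loss
#     # back to fp32 and, for float16, adds loss scaling and inf/nan checking.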


@register_pass("auto_parallel_amp")
class AMPPass(PassBase):
    def __init__(self):
        super().__init__()
        self.set_attr("dtype", "")  # fp16/bf16
        self.set_attr("loss", None)
        self.set_attr("dist_context", None)
        self.set_attr("custom_white_list", None)
        self.set_attr("custom_black_list", None)
        self.set_attr("custom_black_varnames", None)
        self.set_attr("init_loss_scaling", 32768.0)
        self.set_attr("incr_every_n_steps", 1000)
        self.set_attr("decr_every_n_nan_or_inf", 2)
        self.set_attr("incr_ratio", 2.0)
        self.set_attr("decr_ratio", 0.8)
        self.set_attr("use_dynamic_loss_scaling", False)
        self.set_attr("input_data", [])
        self.set_attr("params_grads", [])
        self._loss = None
        self._loss_scaling = None
        self._num_good_steps = None
        self._num_bad_steps = None

    def _check_self(self):
        if self.get_attr("dtype") not in ["float16", "bfloat16"]:
            return False
        if self.get_attr("init_loss_scaling") < 0:
            return False
        if self.get_attr("incr_every_n_steps") < 0:
            return False
        if self.get_attr("decr_every_n_nan_or_inf") < 0:
            return False
        if self.get_attr("incr_ratio") < 0:
            return False
        if self.get_attr("decr_ratio") < 0:
            return False
        if self.get_attr("dist_context") is None:
            return False
        return True

    def _check_conflict(self, other_pass):
        return True

    # NOTE: why can AMPPass override apply_single_impl instead of apply_impl?
    # AMP is an optimization pass for the serial program; in the distributed
    # scenario, all ranks should apply the same modification.
    def _apply_single_impl(self, main_program, startup_program, context):
        self.dist_context = self.get_attr("dist_context")
        self.params_grads = self.get_attr("params_grads")
        self.amp_dtype = self.get_attr("dtype")

        amp_lists = AMPLists(
            set(self.get_attr("custom_white_list")),
            set(self.get_attr("custom_black_list")),
            set(self.get_attr("custom_black_varnames")),
            self.amp_dtype,
        )

        with paddle.static.program_guard(main_program, startup_program):
            amp_state = AMPState(
                main_program, amp_lists, self.amp_dtype, self.dist_context
            )
            is_train = amp_state.build_state()

            if is_train:
                self._update_backward_cast_ops()
                self._cast_loss()

            if is_train and self.amp_dtype == "float16":
                self._init_amp_var()
                self._scale_loss()
                if (
                    self.get_attr("use_dynamic_loss_scaling")
                    or self.get_attr("init_loss_scaling") != 1.0
                ):
                    grads, found_inf = self._check_and_update_gradient()

                if self.get_attr("use_dynamic_loss_scaling"):
                    self._update_loss_scaling(grads, found_inf)

    def _update_backward_cast_ops(self):
        """
        move param-grad casts to the end of the backward segment
        in order to enable fp16 allreduce
        """
        # TODO filter optimize ops in future

        main_block = paddle.static.default_main_program().global_block()
        main_block._sync_with_cpp()

        for p, g in self.params_grads:
            op = g.op
            if g.dtype == core.VarDesc.VarType.FP32 and op.type == 'cast':
                if int(op.attr('op_role')) == int(
                    OpRole.Backward
                ) and op.has_attr('op_role_var'):
                    op._remove_attr("op_role_var")

                post_ops = find_true_post_op(main_block.ops, op, g.name)
                if post_ops:
                    raise ValueError(
                        f"The cast op {op}'s output should not be "
                        "used by a non-optimize op, however, it "
                        f"is used by {post_ops[0]}"
                    )

                if op == main_block.ops[-1]:
                    continue

                # add new op in the python and cpp at the same time
                new_op_desc = main_block.desc.append_op()
                new_op_desc.copy_from(op.desc)
                new_op = paddle.static.Operator(
                    block=main_block,
                    desc=new_op_desc,
                    type=None,
                    inputs=None,
                    outputs=None,
                    attrs=None,
                )
                main_block.ops.append(new_op)

                # dist attr
                param_dist_attr = (
                    self.dist_context.get_tensor_dist_attr_for_program(p)
                )
                output_dist_attr = (
                    self.dist_context.get_tensor_dist_attr_for_program(
                        main_block.var(op.output_arg_names[0])
                    )
                )
                assert param_dist_attr is not None
                assert output_dist_attr is not None
                naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                    new_op,
                    param_dist_attr.process_mesh,
                    param_dist_attr.dims_mapping,
                    self.dist_context,
                )

                output_dist_attr.process_mesh = param_dist_attr.process_mesh
                output_dist_attr.dims_mapping = param_dist_attr.dims_mapping

                op_idx = find_op_index(main_block.desc, op.desc)
                if op_idx == -1:
                    raise ValueError(f"The op {op} is not in program")
                main_block._remove_op(op_idx, sync=False)

        main_block._sync_with_cpp()

    def _check_and_update_gradient(self):

        main_block = paddle.static.default_main_program().global_block()
        main_block._sync_with_cpp()

        grads = [g for _, g in self.params_grads]
        check_type(grads, 'x', (tuple, list), 'check_finite_and_unscale')
        for e in grads:
            check_variable_and_dtype(
                e,
                "x",
                ['float16', 'float32', 'float64'],
                'check_finite_and_unscale',
            )

        found_inf = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(['find_infinite_scale', 'tmp'])
            ),
            shape=[1],
            dtype='bool',
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=False,
        )
        set_var_dist_attr(
            self.dist_context, found_inf, [-1], world_process_group.ranks
        )

        inputs = {'X': grads, 'Scale': self._loss_scaling}
        outputs = {'Out': grads, 'FoundInfinite': found_inf}
        attrs = {'op_role': OpRole.Optimize}
        new_op = main_block.append_op(
            type='check_finite_and_unscale',
            inputs=inputs,
            outputs=outputs,
            attrs=attrs,
        )

        # Constructing dist attr from op_desc can
        # give all inputs and outputs default dist attrs
        new_op_dist_attr = OperatorDistAttr(new_op.desc)
        new_op_dist_attr.process_mesh = ProcessMesh(world_process_group.ranks)
        new_op_dist_attr.impl_idx = 0
        if len(world_process_group.ranks) > 1:
            new_op_dist_attr.impl_type = "check_finite_and_unscale"
        for g in grads:
            g_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(g)
            assert g_dist_attr is not None
            new_op_dist_attr.set_input_dims_mapping(
                g.name, g_dist_attr.dims_mapping
            )
            new_op_dist_attr.set_output_dims_mapping(
                g.name, g_dist_attr.dims_mapping
            )
        self.dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
        return grads, found_inf

    def _init_amp_var(self):
        self._loss_scaling = paddle.static.create_global_var(
            name=unique_name.generate("loss_scaling"),
            shape=[1],
            value=self.get_attr("init_loss_scaling"),
            dtype='float32',
            persistable=True,
        )
        set_var_dist_attr(
            self.dist_context,
            self._loss_scaling,
            [-1],
            world_process_group.ranks,
        )

        if self.get_attr("use_dynamic_loss_scaling"):
            self._num_good_steps = paddle.static.create_global_var(
                name=unique_name.generate("num_good_steps"),
                shape=[1],
                value=0,
                dtype='int32',
                persistable=True,
            )
            set_var_dist_attr(
                self.dist_context,
                self._num_good_steps,
                [-1],
                world_process_group.ranks,
            )

            self._num_bad_steps = paddle.static.create_global_var(
                name=unique_name.generate("num_bad_steps"),
                shape=[1],
                value=0,
                dtype='int32',
                persistable=True,
            )
            set_var_dist_attr(
                self.dist_context,
                self._num_bad_steps,
                [-1],
                world_process_group.ranks,
            )

    def _cast_loss(self):

        main_block = paddle.static.default_main_program().global_block()
        main_block._sync_with_cpp()

        loss = self.get_attr("loss")
        assert loss is not None
        loss_op = loss.op
        loss_op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
            loss_op
        )

        if loss.dtype != core.VarDesc.VarType.FP32:

            tmp_name = unique_name.generate(loss.name + ".cast_fp32")
            cast_loss = main_block.create_var(
                name=tmp_name, dtype=core.VarDesc.VarType.FP32
            )
            loss_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(
                loss
            )
            ref_mesh = loss_op_dist_attr.process_mesh
            self.dist_context.set_tensor_dist_attr_for_program(
                cast_loss, loss_dist_attr
            )

            # forward
            loss_op_idx = find_op_index(main_block.desc, loss_op.desc)
            cast_op = main_block._insert_op(
                loss_op_idx + 1,
                type='cast',
                inputs={'X': [loss]},
                outputs={'Out': [cast_loss]},
                attrs={
                    "in_dtype": loss.dtype,
                    "out_dtype": core.VarDesc.VarType.FP32,
                    "op_role": loss_op.all_attrs()[OP_ROLE_KEY],
                },
            )

            loss_op._set_attr(OP_ROLE_KEY, OpRole.Forward)
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                cast_op, ref_mesh, [-1 for i in loss.shape], self.dist_context
            )

            # backward
            first_backward_op = main_block.ops[loss_op_idx + 2]
            assert (
                first_backward_op.type == "fill_constant"
                and int(first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
            )
            cast_loss_grad = main_block.create_var(
                name=unique_name.generate(tmp_name + "@GRAD"),
                shape=loss.shape,
                dtype=core.VarDesc.VarType.FP32,
                persistable=loss.persistable,
            )
            set_var_dist_attr(
                self.dist_context,
                cast_loss_grad,
                [-1 for i in loss.shape],
                ref_mesh,
            )

            pre_grad_name = first_backward_op.output_arg_names[0]
            first_backward_op._rename_output(pre_grad_name, cast_loss_grad.name)
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                first_backward_op,
                ref_mesh,
                [-1 for i in loss.shape],
                self.dist_context,
            )
            cast_grad_op = main_block._insert_op(
                loss_op_idx + 3,
                type='cast',
                inputs={'X': [cast_loss_grad]},
                outputs={'Out': [pre_grad_name]},
                attrs={
                    "in_dtype": core.VarDesc.VarType.FP32,
                    "out_dtype": _str_to_dtype(self.amp_dtype),
                    "op_role": OpRole.Backward,
                },
            )
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                cast_grad_op,
                ref_mesh,
                [-1 for i in loss.shape],
                self.dist_context,
            )
            loss_op = cast_op
            loss = cast_loss
        self._loss = loss
        main_block._sync_with_cpp()

    def _scale_loss(self):

        main_block = paddle.static.default_main_program().global_block()
        loss = self.get_attr("loss")
        assert loss is not None
        loss_op = loss.op
        loss_op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
            loss_op
        )

        if (
            self.get_attr("use_dynamic_loss_scaling")
            or self.get_attr("init_loss_scaling") != 1.0
        ):

            loss_op_idx = find_op_index(main_block.desc, loss_op.desc)

            # forward
            ref_mesh = loss_op_dist_attr.process_mesh
            scaled_loss = main_block.create_var(
                name=unique_name.generate("scaled_loss"),
                shape=loss.shape,
                dtype=loss.dtype,
                persistable=loss.persistable,
            )
            set_var_dist_attr(
                self.dist_context,
                scaled_loss,
                [-1 for i in loss.shape],
                ref_mesh,
            )

            elementwise_mul_op = main_block._insert_op(
                loss_op_idx + 1,
                type='elementwise_mul',
                inputs={'X': [loss], 'Y': [self._loss_scaling]},
                outputs={'Out': [scaled_loss]},
                attrs={
                    'op_role': loss_op.all_attrs()[OP_ROLE_KEY],
                },
            )
            loss_op._set_attr(OP_ROLE_KEY, OpRole.Forward)
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                elementwise_mul_op,
                ref_mesh,
                [-1 for i in loss.shape],
                self.dist_context,
            )

            # backward
            first_backward_op = main_block.ops[loss_op_idx + 2]
            assert (
                first_backward_op.type == "fill_constant"
                and int(first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
            )
            scaled_loss_grad = main_block.create_var(
                name=unique_name.generate("scaled_loss") + "@GRAD",
                shape=loss.shape,
                dtype=loss.dtype,
                persistable=loss.persistable,
            )
            set_var_dist_attr(
                self.dist_context,
                scaled_loss_grad,
                [-1 for i in loss.shape],
                ref_mesh,
            )
            pre_grad_name = first_backward_op.output_arg_names[0]
            first_backward_op._rename_output(
                pre_grad_name, scaled_loss_grad.name
            )
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                first_backward_op,
                ref_mesh,
                [-1 for i in loss.shape],
                self.dist_context,
            )
            scaled_loss_grad.op = first_backward_op
            # FIXME(JZ-LIANG) a trick to insert backward op
            main_block._sync_with_cpp()
            elementwise_mul_grad_op_desc = main_block.desc._insert_op(
                loss_op_idx + 3
            )
            elementwise_mul_grad_op_desc.set_type("elementwise_mul_grad")
            elementwise_mul_grad_op_desc.set_input(
                'Out@GRAD', [scaled_loss_grad.name]
            )
            elementwise_mul_grad_op_desc.set_input('X', [loss.name])
            elementwise_mul_grad_op_desc.set_input(
                'Y', [self._loss_scaling.name]
            )
            elementwise_mul_grad_op_desc.set_output('X@GRAD', [pre_grad_name])
            elementwise_mul_grad_op_desc.set_output('Y@GRAD', [])
            elementwise_mul_grad_op_desc._set_attr(OP_ROLE_KEY, OpRole.Backward)
            elementwise_mul_grad_op_desc._set_attr('axis', -1)
            elementwise_mul_grad_op = paddle.static.Operator(
                main_block, elementwise_mul_grad_op_desc
            )
            main_block.ops.insert(loss_op_idx + 3, elementwise_mul_grad_op)
            main_block._sync_with_cpp()
            elementwise_mul_grad_op = main_block.ops[loss_op_idx + 3]
            assert elementwise_mul_grad_op.type == "elementwise_mul_grad"
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                elementwise_mul_grad_op,
                ref_mesh,
                [-1 for i in loss.shape],
                self.dist_context,
            )
        else:
            scaled_loss = loss
        self._loss = scaled_loss
        main_block._sync_with_cpp()

    def _update_loss_scaling(self, grads, found_inf):

        main_block = paddle.static.default_main_program().global_block()
        main_block._sync_with_cpp()

        check_variable_and_dtype(
            self._loss_scaling,
            "prev_loss_scaling",
            ['float32', 'float64'],
            "update_loss_scaling",
        )
        check_type(grads, 'x', (tuple, list), 'update_loss_scaling')
        for e in grads:
            check_variable_and_dtype(
                e, "x", ['float16', 'float32', 'float64'], 'update_loss_scaling'
            )
            if e.dtype == core.VarDesc.VarType.FP16:
                assert (
                    self._loss_scaling.dtype == core.VarDesc.VarType.FP32
                ), "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
            else:
                assert (
                    self._loss_scaling.dtype == e.dtype
                ), "The dtype of prev_loss_scaling should be equal to the dtype of x."

        inputs = {
            'X': grads,
            'FoundInfinite': found_inf,
            'PrevLossScaling': self._loss_scaling,
            'InGoodSteps': self._num_good_steps,
            'InBadSteps': self._num_bad_steps,
        }

        outputs = {
            'Out': grads,
            'LossScaling': self._loss_scaling,
            'OutGoodSteps': self._num_good_steps,
            'OutBadSteps': self._num_bad_steps,
        }

        attrs = {
            'incr_every_n_steps': self.get_attr("incr_every_n_steps"),
            'decr_every_n_nan_or_inf': self.get_attr("decr_every_n_nan_or_inf"),
            'incr_ratio': self.get_attr("incr_ratio"),
            'decr_ratio': self.get_attr("decr_ratio"),
            'stop_update': self.get_attr("stop_update"),
            'op_role': OpRole.Optimize,
        }

        new_op = main_block.append_op(
            type='update_loss_scaling',
            inputs=inputs,
            outputs=outputs,
            attrs=attrs,
        )

        # Constructing dist attr from op_desc can
        # give all inputs and outputs default dist attrs
        new_op_dist_attr = OperatorDistAttr(new_op.desc)
        new_op_dist_attr.process_mesh = ProcessMesh(world_process_group.ranks)
        new_op_dist_attr.impl_idx = 0
        if len(world_process_group.ranks) > 1:
            new_op_dist_attr.impl_type = "update_loss_scaling"
        for g in grads:
            g_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(g)
            assert g_dist_attr is not None
            new_op_dist_attr.set_input_dims_mapping(
                g.name, g_dist_attr.dims_mapping
            )
            new_op_dist_attr.set_output_dims_mapping(
                g.name, g_dist_attr.dims_mapping
            )
        self.dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)

        main_block._sync_with_cpp()

    def get_loss(self):
        # AMP might change the effective loss variable of the network, which
        # would affect subsequent passes that rely on the loss, so return the
        # effective loss after the AMP pass.

        if self._loss:
            return self._loss
        else:
            return self.get_attr("loss")
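

# Example of applying this pass (an illustrative sketch, not code from this
# file: the attribute values, the loss/dist_context/params_grads objects, and
# the pass_context come from the caller; new_pass is assumed to be importable
# from paddle.distributed.passes):
#
#     from paddle.distributed.passes import new_pass
#
#     amp_pass = new_pass(
#         "auto_parallel_amp",
#         {
#             "dtype": "float16",
#             "loss": loss,
#             "dist_context": dist_context,
#             "params_grads": params_grads,
#             "custom_white_list": [],
#             "custom_black_list": [],
#             "custom_black_varnames": [],
#         },
#     )
#     amp_pass.apply([main_program], [startup_program], pass_context)
#     loss = amp_pass.get_loss()  # AMP may replace the effective loss variable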