# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistAttr
from paddle.distributed.auto_parallel.process_group import (
    get_world_process_group,
)
from paddle.distributed.auto_parallel.utils import (
    naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
    set_var_dist_attr,
)
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from paddle.fluid.data_feeder import check_type, check_variable_and_dtype
from paddle.framework import core
from paddle.static.amp.bf16.amp_utils import (
    AutoMixedPrecisionListsBF16,
    _is_in_fp32_varnames,
)
from paddle.static.amp.fp16_utils import (
    AutoMixedPrecisionLists,
    _is_in_black_varnames,
    _keep_fp32_input,
    _keep_fp32_output,
    _rename_arg,
    _valid_types,
    find_op_index,
    find_true_post_op,
    find_true_prev_op,
)
from paddle.utils import unique_name

from ..auto_parallel.process_mesh import ProcessMesh
from ..auto_parallel.utils import (
    is_backward_op,
    is_forward_op,
    is_loss_grad_op,
    is_loss_op,
    is_optimize_op,
)
from .pass_base import PassBase, register_pass

world_process_group = get_world_process_group()

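# Op types listed below are left untouched by the AMP pass: no dtype rewrite
# and no cast insertion is performed for them.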
__amp_skip_ops__ = [
    'create_py_reader',
    'create_double_buffer_reader',
    'cast',
    'while',
]


def _dtype_to_str(dtype):
    if dtype == core.VarDesc.VarType.FP16:
        return 'fp16'
    elif dtype == core.VarDesc.VarType.BF16:
        return 'bf16'
    else:
        return 'fp32'


def _str_to_dtype(dstr):
    if dstr == 'float16':
        return core.VarDesc.VarType.FP16
    elif dstr == 'bfloat16':
        return core.VarDesc.VarType.BF16
    else:
        return core.VarDesc.VarType.FP32


class AMPLists:
    def __init__(
        self,
        white_list=None,
        black_list=None,
        black_varnames=None,
        dtype="float16",
    ):
        self._amp_list = None
        if dtype == "float16":
            self._amp_list = AutoMixedPrecisionLists(
                set(white_list), set(black_list), set(black_varnames)
            )
        elif dtype == "bfloat16":
            self._amp_list = AutoMixedPrecisionListsBF16(
                set(white_list), set(black_list), set(black_varnames)
            )

        assert self._amp_list is not None
        self._dtype = dtype
        self._is_float16 = dtype == "float16"

    @property
    def white_list(self):
        if self._is_float16:
            return self._amp_list.white_list
        else:
            return self._amp_list.bf16_list

    @property
    def black_list(self):
        if self._is_float16:
            return self._amp_list.black_list
        else:
            return self._amp_list.fp32_list

    @property
    def gray_list(self):
        return self._amp_list.gray_list

    @property
    def black_varnames(self):
        if self._is_float16:
            return self._amp_list.black_varnames
        else:
            return self._amp_list.fp32_varnames

    @property
    def is_fp16(self):
        return self._is_float16

    @property
    def dtype(self):
        return self._dtype

    @property
    def amp_list(self):
        return self._amp_list

    def _is_in_black_fp32_varnames(self, op):
        if self._is_float16:
            return _is_in_black_varnames(op, self._amp_list)
        else:
            return _is_in_fp32_varnames(op, self._amp_list)

    def _op_keep_fp32_input(self, op, in_name):
        if self._is_float16:
            return _keep_fp32_input(op, in_name)
        else:
            if op.type in ['batch_norm', 'layer_norm']:
                return in_name != 'X'
            if op.type == 'fused_bn_add_activation':
                return in_name not in {'X', 'Z'}
            return False

    def _op_keep_fp32_output(self, op, out_name):
        if self._is_float16:
            return _keep_fp32_output(op, out_name)
        else:
            if op.type in [
                'batch_norm',
                'fused_bn_add_activation',
                'layer_norm',
            ]:
                return out_name != 'Y'
            return False


class AMPState:
    def __init__(self, program, amp_lists, amp_dtype, dist_context):
        self.program = program
        self.dist_context = dist_context
        self.amp_lists = amp_lists
        self.amp_dtype = amp_dtype
        self.grad_op_to_op_map = (
            dist_context.dist_op_context.grad_op_id_to_op_id
        )

        # op_id --> True/False. 'True' means that the current op is in fp16/bf16 mode.
        self._op_fp16_dict = {}
        # fwd_op_id --> {old_name: cast_name}
        self._var_name_dict = {}
        # out_var_name --> [op_ids]
        self.out_var_op_deps = {}

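    # Returns True (run in fp16/bf16), False (keep fp32), or None when the op
    # was never classified (e.g. skipped or optimizer ops).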
    def _is_fp16_op(self, op_id):
        return self._op_fp16_dict.get(op_id, None)

    def build_state(self):
        is_train = False
        for block in self.program.blocks:
            for op in block.ops:
                # record inplace operations and their outputs
                for name in op.output_arg_names:
                    if name not in self.out_var_op_deps:
                        self.out_var_op_deps[name] = [op.desc.original_id()]
                    else:
                        self.out_var_op_deps[name].extend(
                            [op.desc.original_id()]
                        )

                if is_loss_grad_op(op):
                    is_train = True

                if op.type in __amp_skip_ops__:
                    continue

                if is_forward_op(op):
                    self._mark_black_white_ops(op, block.ops, block)
                elif is_backward_op(op):
                    if op.desc.original_id() in self.grad_op_to_op_map:
                        fwd_op_id = self.grad_op_to_op_map[
                            op.desc.original_id()
                        ]
                        assert fwd_op_id in self._op_fp16_dict, "{}".format(
                            str(op)
                        )
                        self._op_fp16_dict[
                            op.desc.original_id()
                        ] = self._is_fp16_op(fwd_op_id)
                elif is_optimize_op(op):
                    break

        # insert cast ops
        for block in self.program.blocks:
            self._cast_block(block)

        return is_train

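    # Classification policy: white-list ops run in low precision, black-list
    # ops (and ops touching black varnames) stay in fp32, gray-list ops follow
    # the precision of the ops producing their inputs, and anything
    # unclassified defaults to fp32 for numerical safety.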
    def _mark_black_white_ops(self, op, ops, block):

        # ernie inference trick
        if op.type == "assign" and "array_" in op.input_arg_names[0]:
            self._op_fp16_dict[op.desc.original_id()] = False
            return

        # If the assign op is an in-place operation, its execution mode should match that of the op which created its output var.
        if op.type == "assign":
            out_name = op.output_arg_names[0]
            if len(self.out_var_op_deps[out_name]) > 1:
                if not self._is_fp16_op(self.out_var_op_deps[out_name][0]):
                    self._op_fp16_dict[op.desc.original_id()] = False
                else:
                    self._op_fp16_dict[op.desc.original_id()] = True
                return

        if (
            self.amp_lists.black_varnames is not None
            and self.amp_lists._is_in_black_fp32_varnames(op)
        ):
            self._op_fp16_dict[op.desc.original_id()] = False
            return
        if op.type in self.amp_lists.black_list:
            self._op_fp16_dict[op.desc.original_id()] = False
        elif op.type in self.amp_lists.white_list:
            self._op_fp16_dict[op.desc.original_id()] = True
        elif op.type in self.amp_lists.gray_list:
            is_black_op = False
            is_white_op = False
            for in_name in op.input_names:
                # if this op has inputs
                if in_name:
                    for in_var_name in op.input(in_name):
                        in_var = block._var_recursive(in_var_name)
                        # this in_var isn't the output of other op
                        if in_var.op is None:
                            continue
                        elif in_var.op is op:
                            prev_op = find_true_prev_op(ops, op, in_var_name)
                            if prev_op is None:
                                continue
                        else:
                            prev_op = in_var.op
                        # if it's one of inputs
                        if (
                            self._is_fp16_op(prev_op.desc.original_id())
                            is False
                            or prev_op.type in self.amp_lists.black_list
                        ):
                            is_black_op = True
                        elif (
                            self._is_fp16_op(prev_op.desc.original_id()) is True
                            or prev_op.type in self.amp_lists.white_list
                        ):
                            is_white_op = True
            if is_black_op:
                self._op_fp16_dict[op.desc.original_id()] = False
            elif is_white_op:
                self._op_fp16_dict[op.desc.original_id()] = True
            else:
                pass
        else:
            # For numerical safety, apply fp32 computation to ops whose
            # list membership could not be determined.
            self._op_fp16_dict[op.desc.original_id()] = False

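    # Walk every op of the block and insert the cast ops implied by the
    # classification above; forward and backward ops are handled separately so
    # that gradient vars stay consistent with their (possibly cast) forward
    # counterparts.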
    def _cast_block(self, block):
        idx = 0
        appended_grad_times = 0
        while idx < len(block.ops):
            op = block.ops[idx]
            num_cast_ops = 0

            if op.type in __amp_skip_ops__:
                idx += 1
                continue

            elif is_forward_op(op):
                if self._is_fp16_op(op.desc.original_id()) is False:
                    num_cast_ops = self._insert_cast_op_forward(
                        block,
                        op,
                        idx,
                        _str_to_dtype(self.amp_dtype),
                        core.VarDesc.VarType.FP32,
                        self.dist_context,
                    )
                elif self._is_fp16_op(op.desc.original_id()) is True:
                    if self.amp_dtype == "bfloat16":
                        if op.has_attr('use_mkldnn'):
                            op._set_attr('use_mkldnn', True)
                            op._set_attr('mkldnn_data_type', 'bfloat16')
                        elif (
                            op.has_attr('dtype')
                            and op.attr('dtype') == core.VarDesc.VarType.FP32
                        ):
                            op._set_attr('dtype', core.VarDesc.VarType.BF16)
                    num_cast_ops = self._insert_cast_op_forward(
                        block,
                        op,
                        idx,
                        core.VarDesc.VarType.FP32,
                        _str_to_dtype(self.amp_dtype),
                        self.dist_context,
                    )
            elif is_backward_op(op):
                # NOTE: the mapping in `grad_var_to_var` may change when a var is cast,
                # which affects how the dist_op inserts the allreduce_sum op.
                op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
                    op
                )
                if is_backward_op(op) and (
                    is_forward_op(block.ops[idx - 1])
                    or is_loss_op(block.ops[idx - 1])
                ):
                    if not op_dist_attr.is_recompute:
                        appended_grad_times += 1

                if op.desc.original_id() in self.grad_op_to_op_map:
                    if self._is_fp16_op(op.desc.original_id()) is False:  # fp32
                        num_cast_ops = self._insert_cast_op_backward(
                            block,
                            op,
                            idx,
                            _str_to_dtype(self.amp_dtype),
                            core.VarDesc.VarType.FP32,
                            self.dist_context,
                            appended_grad_times,
                        )
                    elif (
                        self._is_fp16_op(op.desc.original_id()) is True
                    ):  # fp16/bf16
                        if self.amp_dtype == "bfloat16":
                            if op.has_attr('use_mkldnn'):
                                op._set_attr('use_mkldnn', True)
                                op._set_attr('mkldnn_data_type', 'bfloat16')
                            elif (
                                op.has_attr('dtype')
                                and op.attr('dtype')
                                == core.VarDesc.VarType.FP32
                            ):
                                op._set_attr('dtype', core.VarDesc.VarType.BF16)
                        num_cast_ops = self._insert_cast_op_backward(
                            block,
                            op,
                            idx,
                            core.VarDesc.VarType.FP32,
                            _str_to_dtype(self.amp_dtype),
                            self.dist_context,
                            appended_grad_times,
                        )
                elif op.type == "sum":
                    # all input dtypes of sum should be equal and the output dtype should follow the input
                    out_var_name = op.desc.output_arg_names()[0]
                    in_var_name = op.desc.input_arg_names()[0]
                    out_var = block.var(out_var_name)
                    in_var = block._find_var_recursive(in_var_name)
                    for in_var_name in op.input_arg_names:
                        assert (
                            in_var.dtype == block.var(in_var_name).dtype
                        ), "{}, {}, {}".format(
                            in_var, block.var(in_var_name), str(op)
                        )
                    out_var.desc.set_dtype(in_var.dtype)
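                # 257 == OpRole.Backward | OpRole.Loss: the op that seeds the
                # loss gradient, which needs no cast handling here.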
                elif int(op.attr('op_role')) == 257:
                    pass
                else:
                    raise ValueError(
                        "'{}' op is not supported in the complete amp pass.".format(
                            op.type
                        )
                    )
            idx += num_cast_ops + 1
        block._sync_with_cpp()

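    # For a single forward op: inputs of the wrong dtype get a cast op
    # (reusing an existing cast var when possible), while outputs are simply
    # retyped in place.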
    def _insert_cast_op_forward(
        self, block, op, idx, src_dtype, dst_dtype, dist_context
    ):
        """
        only for forward cast
        modified from paddle.static.amp
        """
        num_cast_ops = 0
        var_name_dict = {}
        for in_name in op.input_names:
            if (
                src_dtype == core.VarDesc.VarType.FP32
                and self.amp_lists._op_keep_fp32_input(op, in_name)
            ):
                continue
            for in_var_name in op.input(in_name):
                in_var = block._find_var_recursive(in_var_name)
                if in_var.type not in _valid_types or in_var.dtype == dst_dtype:
                    continue
                if in_var.dtype == src_dtype:
                    cast_name = (
                        in_var.name + '.cast_' + _dtype_to_str(dst_dtype)
                    )
                    cast_var = block.vars.get(cast_name)
                    var_name_dict[in_var.name] = cast_name
                    consume_op_attr = dist_context.get_op_dist_attr_for_program(
                        op
                    )
                    assert consume_op_attr is not None
                    if cast_var is None or cast_var.dtype != dst_dtype:
                        # NOTE: the cast op and cast var take their dist attr from the op that
                        # consumes the cast var, not from the op that generates it
                        in_var_dist_attr = consume_op_attr.get_input_dist_attr(
                            in_var.name
                        )
                        assert in_var_dist_attr is not None
                        ref_mesh = in_var_dist_attr.process_mesh
                        ref_mapping = in_var_dist_attr.dims_mapping
                        consume_op_attr.set_input_dist_attr(
                            cast_name, in_var_dist_attr
                        )

                        cast_var = block.create_var(
                            name=cast_name,
                            dtype=dst_dtype,
                            persistable=False,
                            stop_gradient=in_var.stop_gradient,
                        )
                        set_var_dist_attr(
                            dist_context, cast_var, ref_mapping, ref_mesh
                        )

                        op_namescope = "/"
                        if op.has_attr('op_namescope'):
                            op_namescope = op.attr('op_namescope')
                        cast_op = block._insert_op_without_sync(
                            idx,
                            type="cast",
                            inputs={"X": in_var},
                            outputs={"Out": cast_var},
                            attrs={
                                "in_dtype": in_var.dtype,
                                "out_dtype": cast_var.dtype,
                            },
                        )
                        cast_op._set_attr(
                            'op_namescope', op_namescope
                        )  # for recompute
                        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                            cast_op, ref_mesh, ref_mapping, dist_context
                        )
                        num_cast_ops += 1
                    else:
                        in_var_dist_attr = consume_op_attr.get_input_dist_attr(
                            in_var.name
                        )
                        consume_op_attr.set_input_dist_attr(
                            cast_name, in_var_dist_attr
                        )
                    _rename_arg(op, in_var.name, cast_name)
                else:
                    if op.has_attr('in_dtype'):
                        op._set_attr('in_dtype', dst_dtype)
        self._var_name_dict[op.desc.original_id()] = var_name_dict

        if (
            src_dtype == core.VarDesc.VarType.FP32
            and dst_dtype == _str_to_dtype(self.amp_dtype)
        ):
            for out_name in op.output_names:
                if self.amp_lists._op_keep_fp32_output(op, out_name):
                    continue
                for out_var_name in op.output(out_name):
                    out_var = block._var_recursive(out_var_name)
                    if out_var.type not in _valid_types:
                        continue
                    if out_var.dtype == core.VarDesc.VarType.FP32:
                        out_var.desc.set_dtype(_str_to_dtype(self.amp_dtype))
                        if op.has_attr('out_dtype'):
                            op._set_attr(
                                'out_dtype', _str_to_dtype(self.amp_dtype)
                            )
        return num_cast_ops

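    # Backward counterpart of _insert_cast_op_forward: grad-op inputs/outputs
    # whose forward vars were cast are renamed to the cast names, and a cast
    # op is inserted to bring the gradient back to its original dtype.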
    def _insert_cast_op_backward(
        self,
        block,
        op,
        idx,
        src_dtype,
        dst_dtype,
        dist_context,
        appended_grad_times,
    ):
        """only for backward cast"""

        def _keep_fp32_input(op, in_name):
            op_type = op.type
            if op_type in ['layer_norm_grad']:
                return in_name not in {'X', 'Y@GRAD'}
            return False

        def _keep_fp32_output(op, out_name):
            op_type = op.type
            if op_type in ['layer_norm_grad']:
                return out_name != 'X@GRAD'
            return False

        num_cast_ops = 0
        original_id = op.desc.original_id()
        dist_op_context = dist_context.dist_op_context
        fwd_op_id = self.grad_op_to_op_map[original_id]

        for in_name in op.input_names:
            if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
                op, in_name
            ):
                for in_var_name in op.input(in_name):
                    in_var = block._var_recursive(in_var_name)
                    assert in_var.dtype == core.VarDesc.VarType.FP32
                continue

            for in_var_name in op.input(in_name):
                in_var = block._var_recursive(in_var_name)
                if in_var.dtype == src_dtype:
                    consume_op_attr = dist_context.get_op_dist_attr_for_program(
                        op
                    )
                    if in_var_name in self._var_name_dict[fwd_op_id]:
                        # NOTE: if the in_var consumed by this grad_op was cast in the forward
                        # pass, it should be renamed and its dist_attr reset.
                        cast_name = self._var_name_dict[fwd_op_id][in_var_name]
                        op.desc._rename_input(in_var_name, cast_name)
                        in_var_dist_attr = consume_op_attr.get_input_dist_attr(
                            in_var_name
                        )
                        consume_op_attr.set_input_dist_attr(
                            cast_name, in_var_dist_attr
                        )
                    else:
                        assert (
                            in_var.dtype == dst_dtype
                        ), "op [{}] expect input [{}] to be dtype [{}] BUT got [{}]. {}".format(
                            op.type,
                            in_name,
                            dst_dtype,
                            in_var.dtype,
                            str(op),
                        )

        for out_name in op.output_names:
            if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_output(
                op, out_name
            ):
                for out_var_name in op.output(out_name):
                    out_var = block._var_recursive(out_var_name)
                    assert out_var.dtype == core.VarDesc.VarType.FP32
                continue

            for out_var_name in op.output(out_name):
                out_var = block._var_recursive(out_var_name)
                out_var_name_prefix = out_var_name[: out_var_name.find("@")]
                fwd_var = block._var_recursive(out_var_name_prefix)
                # NOTE: the dtype of the grad_op's out_var should equal the fwd_var's dtype
                if out_var.dtype != fwd_var.dtype:
                    out_var.desc.set_dtype(fwd_var.dtype)

                if out_var.dtype == src_dtype:
                    if out_var_name_prefix in self._var_name_dict[fwd_op_id]:
                        # NOTE: if the out_var of this grad_op was cast in the forward pass,
                        # it should be renamed and its dist_attr reset; then a cast op is
                        # inserted to convert the cast_var back to the original dtype
                        consume_op_attr = (
                            dist_context.get_op_dist_attr_for_program(op)
                        )
                        fwd_cast_name = self._var_name_dict[fwd_op_id][
                            out_var_name_prefix
                        ]
                        suffix = ""
                        if "@RENAME" in out_var_name:
                            suffix = out_var_name[
                                out_var_name.find("@RENAME") :
                            ]
                        cast_name = fwd_cast_name + "@GRAD" + suffix
                        cast_var = block.vars.get(cast_name)
                        if cast_var is None or cast_var.dtype != dst_dtype:
                            op.desc._rename_output(out_var_name, cast_name)
                            out_var_dist_attr = (
                                consume_op_attr.get_output_dist_attr(
                                    out_var_name
                                )
                            )
                            ref_mesh = out_var_dist_attr.process_mesh
                            ref_mapping = out_var_dist_attr.dims_mapping
                            consume_op_attr.set_output_dist_attr(
                                cast_name, out_var_dist_attr
                            )
                            assert ref_mapping is not None
                            cast_var = block.create_var(
                                name=cast_name,
                                shape=out_var.shape,
                                dtype=dst_dtype,
                                persistable=False,
                                stop_gradient=out_var.stop_gradient,
                            )
                            set_var_dist_attr(
                                dist_context, cast_var, ref_mapping, ref_mesh
                            )
                            dist_op_context.grad_var_to_var[
                                appended_grad_times
                            ][cast_name] = fwd_cast_name

                            cast_op = block._insert_op(
                                idx + 1,
                                type="cast",
                                inputs={"X": cast_var},
                                outputs={"Out": out_var},
                                attrs={
                                    "in_dtype": cast_var.dtype,
                                    "out_dtype": out_var.dtype,
                                    "op_role": OpRole.Backward,
                                },
                            )
                            cast_op._remove_attr("op_role_var")
                            cast_op._remove_attr("op_namescope")
                            cast_op._remove_attr("with_quant_attr")
                            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                                cast_op, ref_mesh, ref_mapping, dist_context
                            )
                            num_cast_ops += 1
                else:
                    assert out_var.dtype == dst_dtype

        return num_cast_ops


@register_pass("auto_parallel_amp")
class AMPPass(PassBase):
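    """
    Auto-parallel AMP pass: classifies the ops of the serial program into
    fp16/bf16 vs. fp32, inserts the required cast ops, and (for float16)
    optionally adds loss scaling, while keeping the distributed attributes of
    every new op and var consistent across ranks.

    Rough usage sketch (assuming the generic pass-registry entry points
    ``new_pass``/``PassBase.apply``; the attribute names mirror the
    ``set_attr`` defaults in ``__init__``)::

        from paddle.distributed.passes import new_pass, PassContext

        amp_pass = new_pass(
            "auto_parallel_amp",
            {
                "dtype": "float16",
                "loss": loss,
                "dist_context": dist_context,
                "params_grads": params_grads,
                "custom_white_list": [],
                "custom_black_list": [],
                "custom_black_varnames": [],
            },
        )
        amp_pass.apply([main_program], [startup_program], PassContext())
    """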
    def __init__(self):
        super().__init__()
        self.set_attr("dtype", "")  # fp16/bf16
        self.set_attr("loss", None)
        self.set_attr("dist_context", None)
        self.set_attr("custom_white_list", None)
        self.set_attr("custom_black_list", None)
        self.set_attr("custom_black_varnames", None)
        self.set_attr("init_loss_scaling", 32768.0)
        self.set_attr("incr_every_n_steps", 1000)
        self.set_attr("decr_every_n_nan_or_inf", 2)
        self.set_attr("incr_ratio", 2.0)
        self.set_attr("decr_ratio", 0.8)
        self.set_attr("use_dynamic_loss_scaling", False)
        self.set_attr("input_data", [])
        self.set_attr("params_grads", [])
        self._loss = None
        self._loss_scaling = None
        self._num_good_steps = None
        self._num_bad_steps = None

    def _check_self(self):
        if self.get_attr("dtype") not in ["float16", "bfloat16"]:
            return False
        if self.get_attr("init_loss_scaling") < 0:
            return False
        if self.get_attr("incr_every_n_steps") < 0:
            return False
        if self.get_attr("decr_every_n_nan_or_inf") < 0:
            return False
        if self.get_attr("incr_ratio") < 0:
            return False
        if self.get_attr("decr_ratio") < 0:
            return False
        if self.get_attr("dist_context") is None:
            return False
        return True

    def _check_conflict(self, other_pass):
        return True

    # NOTE: why does this pass override apply_single_impl instead of
    # apply_impl? AMP is an optimization pass for the serial program; in a
    # distributed scenario, every rank should apply the same modification.
    def _apply_single_impl(self, main_program, startup_program, context):
        self.dist_context = self.get_attr("dist_context")
        self.params_grads = self.get_attr("params_grads")
        self.amp_dtype = self.get_attr("dtype")

        amp_lists = AMPLists(
            set(self.get_attr("custom_white_list")),
            set(self.get_attr("custom_black_list")),
            set(self.get_attr("custom_black_varnames")),
            self.amp_dtype,
        )

        with paddle.static.program_guard(main_program, startup_program):
            amp_state = AMPState(
                main_program, amp_lists, self.amp_dtype, self.dist_context
            )
            is_train = amp_state.build_state()

            if is_train:
                self._update_backward_cast_ops()
                self._cast_loss()

            if is_train and self.amp_dtype == "float16":
                self._init_amp_var()
                self._scale_loss()
                if (
                    self.get_attr("use_dynamic_loss_scaling")
                    or self.get_attr("init_loss_scaling") != 1.0
                ):
                    grads, found_inf = self._check_and_update_gradient()

                if self.get_attr("use_dynamic_loss_scaling"):
                    self._update_loss_scaling(grads, found_inf)

    def _update_backward_cast_ops(self):
        """
        move the param-grad cast to the end of the backward segment
        in order to enable fp16 allreduce
        """
        # TODO filter optimize ops in future

        main_block = paddle.static.default_main_program().global_block()
        main_block._sync_with_cpp()

        for p, g in self.params_grads:
            op = g.op
            if g.dtype == core.VarDesc.VarType.FP32 and op.type == 'cast':
                if int(op.attr('op_role')) == int(
                    OpRole.Backward
                ) and op.has_attr('op_role_var'):
                    op._remove_attr("op_role_var")

                post_ops = find_true_post_op(main_block.ops, op, g.name)
                if post_ops:
                    raise ValueError(
                        f"The cast op {op}'s output should not be "
                        "used by a non-optimize op, however, it "
                        f"is used by {post_ops[0]}"
                    )

                if op == main_block.ops[-1]:
                    continue

                # add new op in the python and cpp at the same time
                new_op_desc = main_block.desc.append_op()
                new_op_desc.copy_from(op.desc)
                new_op = paddle.static.Operator(
                    block=main_block,
                    desc=new_op_desc,
                    type=None,
                    inputs=None,
                    outputs=None,
                    attrs=None,
                )
                main_block.ops.append(new_op)

                # dist attr
                param_dist_attr = (
                    self.dist_context.get_tensor_dist_attr_for_program(p)
                )
                output_dist_attr = (
                    self.dist_context.get_tensor_dist_attr_for_program(
                        main_block.var(op.output_arg_names[0])
                    )
                )
                assert param_dist_attr is not None
                assert output_dist_attr is not None
                naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                    new_op,
                    param_dist_attr.process_mesh,
                    param_dist_attr.dims_mapping,
                    self.dist_context,
                )

                output_dist_attr.process_mesh = param_dist_attr.process_mesh
                output_dist_attr.dims_mapping = param_dist_attr.dims_mapping

                op_idx = find_op_index(main_block.desc, op.desc)
                if op_idx == -1:
                    raise ValueError(f"The op {op} is not in program")
                main_block._remove_op(op_idx, sync=False)

        main_block._sync_with_cpp()

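    # Append a check_finite_and_unscale op over all gradients and give it
    # dist attrs spanning the world process group.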
    def _check_and_update_gradient(self):

        main_block = paddle.static.default_main_program().global_block()
        main_block._sync_with_cpp()

        grads = [g for _, g in self.params_grads]
        check_type(grads, 'x', (tuple, list), 'check_finite_and_unscale')
        for e in grads:
            check_variable_and_dtype(
                e,
                "x",
                ['float16', 'float32', 'float64'],
                'check_finite_and_unscale',
            )

        found_inf = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(['find_infinite_scale', 'tmp'])
            ),
            shape=[1],
            dtype='bool',
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=False,
        )
        set_var_dist_attr(
            self.dist_context, found_inf, [-1], world_process_group.ranks
        )

        inputs = {'X': grads, 'Scale': self._loss_scaling}
        outputs = {'Out': grads, 'FoundInfinite': found_inf}
        attrs = {'op_role': OpRole.Optimize}
        new_op = main_block.append_op(
            type='check_finite_and_unscale',
            inputs=inputs,
            outputs=outputs,
            attrs=attrs,
        )

        # Constructing dist attr from op_desc can
        # give all inputs and outputs default dist attrs
        new_op_dist_attr = OperatorDistAttr(new_op.desc)
        new_op_dist_attr.process_mesh = ProcessMesh(world_process_group.ranks)
        new_op_dist_attr.impl_idx = 0
        if len(world_process_group.ranks) > 1:
            new_op_dist_attr.impl_type = "check_finite_and_unscale"
        for g in grads:
            g_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(g)
            assert g_dist_attr is not None
            new_op_dist_attr.set_input_dims_mapping(
                g.name, g_dist_attr.dims_mapping
            )
            new_op_dist_attr.set_output_dims_mapping(
                g.name, g_dist_attr.dims_mapping
            )
        self.dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
        return grads, found_inf

    def _init_amp_var(self):
        self._loss_scaling = paddle.static.create_global_var(
            name=unique_name.generate("loss_scaling"),
            shape=[1],
            value=self.get_attr("init_loss_scaling"),
            dtype='float32',
            persistable=True,
        )
        set_var_dist_attr(
            self.dist_context,
            self._loss_scaling,
            [-1],
            world_process_group.ranks,
        )

        if self.get_attr("use_dynamic_loss_scaling"):
            self._num_good_steps = paddle.static.create_global_var(
                name=unique_name.generate("num_good_steps"),
                shape=[1],
                value=0,
                dtype='int32',
                persistable=True,
            )
            set_var_dist_attr(
                self.dist_context,
                self._num_good_steps,
                [-1],
                world_process_group.ranks,
            )

            self._num_bad_steps = paddle.static.create_global_var(
                name=unique_name.generate("num_bad_steps"),
                shape=[1],
                value=0,
                dtype='int32',
                persistable=True,
            )
            set_var_dist_attr(
                self.dist_context,
                self._num_bad_steps,
                [-1],
                world_process_group.ranks,
            )

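    # If the loss was produced in low precision, cast it to fp32 in the
    # forward pass and cast its gradient back to the AMP dtype in the backward
    # pass, so that downstream passes see an fp32 loss.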
    def _cast_loss(self):

        main_block = paddle.static.default_main_program().global_block()
        main_block._sync_with_cpp()

        loss = self.get_attr("loss")
        assert loss is not None
        loss_op = loss.op
        loss_op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
            loss_op
        )

        if loss.dtype != core.VarDesc.VarType.FP32:

            tmp_name = unique_name.generate(loss.name + ".cast_fp32")
            cast_loss = main_block.create_var(
                name=tmp_name, dtype=core.VarDesc.VarType.FP32
            )
            loss_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(
935 936
                loss
            )
            ref_mesh = loss_op_dist_attr.process_mesh
            self.dist_context.set_tensor_dist_attr_for_program(
                cast_loss, loss_dist_attr
            )

            # forward
            loss_op_idx = find_op_index(main_block.desc, loss_op.desc)
            cast_op = main_block._insert_op(
                loss_op_idx + 1,
                type='cast',
                inputs={'X': [loss]},
                outputs={'Out': [cast_loss]},
                attrs={
                    "in_dtype": loss.dtype,
                    "out_dtype": core.VarDesc.VarType.FP32,
                    "op_role": loss_op.all_attrs()[OP_ROLE_KEY],
                },
            )

            loss_op._set_attr(OP_ROLE_KEY, OpRole.Forward)
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                cast_op, ref_mesh, [-1 for i in loss.shape], self.dist_context
            )

            # backward
            first_backward_op = main_block.ops[loss_op_idx + 2]
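            # The op right after the inserted cast is expected to be the
            # fill_constant that seeds the loss gradient
            # (op_role 257 == OpRole.Backward | OpRole.Loss).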
            assert (
                first_backward_op.type == "fill_constant"
                and int(first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
            )
            cast_loss_grad = main_block.create_var(
                name=unique_name.generate(tmp_name + "@GRAD"),
                shape=loss.shape,
                dtype=core.VarDesc.VarType.FP32,
                persistable=loss.persistable,
            )
            set_var_dist_attr(
                self.dist_context,
                cast_loss_grad,
                [-1 for i in loss.shape],
                ref_mesh,
            )

            pre_grad_name = first_backward_op.output_arg_names[0]
            first_backward_op._rename_output(pre_grad_name, cast_loss_grad.name)
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                first_backward_op,
                ref_mesh,
                [-1 for i in loss.shape],
                self.dist_context,
            )
            cast_grad_op = main_block._insert_op(
                loss_op_idx + 3,
                type='cast',
                inputs={'X': [cast_loss_grad]},
                outputs={'Out': [pre_grad_name]},
                attrs={
                    "in_dtype": core.VarDesc.VarType.FP32,
                    "out_dtype": _str_to_dtype(self.amp_dtype),
                    "op_role": OpRole.Backward,
                },
            )
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                cast_grad_op,
                ref_mesh,
                [-1 for i in loss.shape],
                self.dist_context,
            )
            loss_op = cast_op
            loss = cast_loss
        self._loss = loss
        main_block._sync_with_cpp()

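    # Multiply the loss by the loss-scaling variable and insert the matching
    # elementwise_mul_grad op so that the scaling also reaches the gradients.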
    def _scale_loss(self):

        main_block = paddle.static.default_main_program().global_block()
        loss = self.get_attr("loss")
        assert loss is not None
        loss_op = loss.op
        loss_op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
            loss_op
        )

        if (
            self.get_attr("use_dynamic_loss_scaling")
            or self.get_attr("init_loss_scaling") != 1.0
        ):

            loss_op_idx = find_op_index(main_block.desc, loss_op.desc)

            # forward
            ref_mesh = loss_op_dist_attr.process_mesh
            scaled_loss = main_block.create_var(
                name=unique_name.generate("scaled_loss"),
                shape=loss.shape,
                dtype=loss.dtype,
                persistable=loss.persistable,
            )
            set_var_dist_attr(
                self.dist_context,
                scaled_loss,
                [-1 for i in loss.shape],
                ref_mesh,
            )

            elementwise_mul_op = main_block._insert_op(
                loss_op_idx + 1,
                type='elementwise_mul',
                inputs={'X': [loss], 'Y': [self._loss_scaling]},
                outputs={'Out': [scaled_loss]},
                attrs={
                    'op_role': loss_op.all_attrs()[OP_ROLE_KEY],
                },
            )
            loss_op._set_attr(OP_ROLE_KEY, OpRole.Forward)
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                elementwise_mul_op,
                ref_mesh,
                [-1 for i in loss.shape],
                self.dist_context,
            )

            # backward
            first_backward_op = main_block.ops[loss_op_idx + 2]
            assert (
                first_backward_op.type == "fill_constant"
                and int(first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
            )
            scaled_loss_grad = main_block.create_var(
                name=unique_name.generate("scaled_loss") + "@GRAD",
                shape=loss.shape,
                dtype=loss.dtype,
                persistable=loss.persistable,
            )
            set_var_dist_attr(
                self.dist_context,
                scaled_loss_grad,
                [-1 for i in loss.shape],
                ref_mesh,
            )
            pre_grad_name = first_backward_op.output_arg_names[0]
            first_backward_op._rename_output(
                pre_grad_name, scaled_loss_grad.name
            )
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                first_backward_op,
                ref_mesh,
                [-1 for i in loss.shape],
                self.dist_context,
            )
            scaled_loss_grad.op = first_backward_op
            # FIXME(JZ-LIANG) a trick to insert backward op
            main_block._sync_with_cpp()
            elementwise_mul_grad_op_desc = main_block.desc._insert_op(
                loss_op_idx + 3
            )
            elementwise_mul_grad_op_desc.set_type("elementwise_mul_grad")
            elementwise_mul_grad_op_desc.set_input(
                'Out@GRAD', [scaled_loss_grad.name]
            )
            elementwise_mul_grad_op_desc.set_input('X', [loss.name])
            elementwise_mul_grad_op_desc.set_input(
                'Y', [self._loss_scaling.name]
            )
            elementwise_mul_grad_op_desc.set_output('X@GRAD', [pre_grad_name])
            elementwise_mul_grad_op_desc.set_output('Y@GRAD', [])
            elementwise_mul_grad_op_desc._set_attr(OP_ROLE_KEY, OpRole.Backward)
            elementwise_mul_grad_op_desc._set_attr('axis', -1)
            elementwise_mul_grad_op = paddle.static.Operator(
                main_block, elementwise_mul_grad_op_desc
            )
            main_block.ops.insert(loss_op_idx + 3, elementwise_mul_grad_op)
            main_block._sync_with_cpp()
            elementwise_mul_grad_op = main_block.ops[loss_op_idx + 3]
            assert elementwise_mul_grad_op.type == "elementwise_mul_grad"
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                elementwise_mul_grad_op,
                ref_mesh,
                [-1 for i in loss.shape],
                self.dist_context,
            )
        else:
            scaled_loss = loss
        self._loss = scaled_loss
        main_block._sync_with_cpp()

    def _update_loss_scaling(self, grads, found_inf):

        main_block = paddle.static.default_main_program().global_block()
        main_block._sync_with_cpp()

        check_variable_and_dtype(
            self._loss_scaling,
            "prev_loss_scaling",
            ['float32', 'float64'],
            "update_loss_scaling",
        )
        check_type(grads, 'x', (tuple, list), 'update_loss_scaling')
        for e in grads:
            check_variable_and_dtype(
                e, "x", ['float16', 'float32', 'float64'], 'update_loss_scaling'
            )
            if e.dtype == core.VarDesc.VarType.FP16:
                assert (
                    self._loss_scaling.dtype == core.VarDesc.VarType.FP32
                ), "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
            else:
                assert (
                    self._loss_scaling.dtype == e.dtype
                ), "The dtype of prev_loss_scaling should be equal to the dtype of x."

        inputs = {
            'X': grads,
            'FoundInfinite': found_inf,
            'PrevLossScaling': self._loss_scaling,
            'InGoodSteps': self._num_good_steps,
1153
            'InBadSteps': self._num_bad_steps,
J

        outputs = {
            'Out': grads,
            'LossScaling': self._loss_scaling,
            'OutGoodSteps': self._num_good_steps,
1160
            'OutBadSteps': self._num_bad_steps,
J

        attrs = {
            'incr_every_n_steps': self.get_attr("incr_every_n_steps"),
            'decr_every_n_nan_or_inf': self.get_attr("decr_every_n_nan_or_inf"),
            'incr_ratio': self.get_attr("incr_ratio"),
            'decr_ratio': self.get_attr("decr_ratio"),
            'stop_update': self.get_attr("stop_update"),
1169
            'op_role': OpRole.Optimize,
J

1172 1173 1174 1175 1176 1177
        new_op = main_block.append_op(
            type='update_loss_scaling',
            inputs=inputs,
            outputs=outputs,
            attrs=attrs,
        )
J
1179 1180 1181 1182
        # Constructing dist attr from op_desc can
        # give all inputs and outputs default dist attrs
        new_op_dist_attr = OperatorDistAttr(new_op.desc)
        new_op_dist_attr.process_mesh = ProcessMesh(world_process_group.ranks)
Z
zhaoyingli 已提交
1183 1184 1185
        new_op_dist_attr.impl_idx = 0
        if len(world_process_group.ranks) > 1:
            new_op_dist_attr.impl_type = "update_loss_scaling"
J
            g_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(g)
            assert g_dist_attr is not None
1189 1190 1191 1192 1193 1194
            new_op_dist_attr.set_input_dims_mapping(
                g.name, g_dist_attr.dims_mapping
            )
            new_op_dist_attr.set_output_dims_mapping(
                g.name, g_dist_attr.dims_mapping
            )
J

        main_block._sync_with_cpp()
1198 1199

    def get_loss(self):
1200
        # the amp might change the effective loss variable for network and
1201
        # therefore would affect the subsequent passes that rely on the loss.
1202
        # return the effective loss after amp pass.
1203 1204 1205 1206 1207

        if self._loss:
            return self._loss
        else:
            return self.get_attr("loss")