# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
16 17
from paddle.distributed.auto_parallel.dist_attribute import (
    OperatorDistributedAttribute,
18 19 20 21
)
from paddle.distributed.auto_parallel.process_group import (
    get_world_process_group,
)
22 23 24 25
from paddle.distributed.auto_parallel.utils import (
    get_loss_op,
    naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
    set_var_dist_attr,
26
)
27 28
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.fluid import unique_name
29
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
30 31 32
    AutoMixedPrecisionLists,
    _dtype_to_str,
    _is_in_black_varnames,
33 34
    _keep_fp32_input,
    _keep_fp32_output,
35
    _rename_arg,
36
    _valid_types,
37
    find_op_index,
38 39 40
    find_true_post_op,
    find_true_prev_op,
)
41 42 43 44 45
from paddle.fluid.data_feeder import check_type, check_variable_and_dtype
from paddle.framework import core

from ..auto_parallel.utils import is_backward_op, is_forward_op, is_loss_op
from .pass_base import PassBase, register_pass
46

Z
zhaoyingli 已提交
47
# The world process group (presumably covering every rank). Its ranks are used
# by _check_and_update_gradient as the process mesh of the found_inf variable
# and of the check_finite_and_unscale op.
world_process_group = get_world_process_group()
J
JZ-LIANG 已提交
48 49


50
class AMPState:
    """Per-block bookkeeping and rewriting for the auto-parallel AMP pass.

    Each op is keyed by ``op.desc.original_id()``. The state records whether
    the op runs in fp16 (``True``), in fp32 (``False``), or was left
    undecided (no entry). ``cast_forward_program`` and
    ``cast_backward_program`` then rewrite ``block`` accordingly: they insert
    ``cast`` ops and rename op inputs/outputs.
    """

    def __init__(self, block):
        self._block = block
        # op_id --> True/False. 'True' means that the current op is in fp16 mode.
        self._op_fp16_dict = {}
        # fwd_op_id --> {old_name: cast_name}
        self._var_name_dict = {}
        # Set to True by _build_state when an op with op_role 257 is found.
        self.is_train = False

    def _is_fp16_op(self, op_id):
        """Return True (fp16), False (fp32) or None (undecided) for ``op_id``."""
        return self._op_fp16_dict.get(op_id, None)

    def _build_state(self, amp_lists, dist_context):
        """Decide the fp16/fp32 mode of every forward and backward op.

        Forward ops are classified by ``_mark_black_white_ops``. A backward op
        inherits the decision made for the forward op it was generated from,
        looked up through ``dist_op_context.grad_op_id_to_op_id``. The scan
        stops at the first optimizer op.

        Args:
            amp_lists: AutoMixedPrecisionLists holding the white/black/gray
                op lists and the black var names.
            dist_context: distributed context of the program.

        Returns:
            bool: True if an op with op_role 257 was seen (training program).
        """
        ops = self._block.ops
        dist_op_context = dist_context.dist_op_context
        forward_ops_marked = False
        for op in ops:
            op_role = int(op.attr('op_role'))
            # NOTE(review): 257 is presumably OpRole.Backward | OpRole.Loss,
            # i.e. the loss-grad op, whose presence marks a training program.
            if op_role == 257:
                self.is_train = True

            if op_role == int(OpRole.Forward):
                # _mark_black_white_ops scans every forward op of the block
                # by itself. Run it once rather than once per forward op,
                # which made this loop quadratic in the number of ops.
                if not forward_ops_marked:
                    self._mark_black_white_ops(amp_lists)
                    forward_ops_marked = True
            elif op_role == int(OpRole.Backward):
                grad_op_id = op.desc.original_id()
                if grad_op_id in dist_op_context.grad_op_id_to_op_id:
                    fwd_op_id = dist_op_context.grad_op_id_to_op_id[grad_op_id]
                    fwd_is_fp16 = self._is_fp16_op(fwd_op_id)
                    # Only propagate decided modes; undecided ops stay unset.
                    if fwd_is_fp16 is not None:
                        self._op_fp16_dict[grad_op_id] = fwd_is_fp16
            elif op_role == int(OpRole.Optimize):
                break

        return self.is_train

    def _mark_black_white_ops(self, amp_lists):
        """
        Mark every forward op of the block as fp16 (True) or fp32 (False).

        The rules below apply to each op before the first backward op.
        ``create_py_reader`` and ``read`` ops are skipped.

        * Ops for which ``_is_in_black_varnames`` holds -> fp32.
        * Op type in the black list -> fp32; in the white list -> fp16.
        * Op type in the gray list -> fp32 if the producer of any input is
          fp32 or black-listed; otherwise fp16 if any producer is fp16 or
          white-listed; otherwise left undecided.
        * Any other op -> fp32, for numerical safety.

        This function is modified from paddle.fluid.contrib.mixed_precision.
        """
        self._block._sync_with_cpp()
        ops = self._block.ops

        for op in ops:
            if int(op.attr('op_role')) == int(OpRole.Backward):
                break
            if op.type == 'create_py_reader' or op.type == 'read':
                continue
            op_id = op.desc.original_id()
            if amp_lists.black_varnames is not None and _is_in_black_varnames(
                op, amp_lists
            ):
                self._op_fp16_dict[op_id] = False
                continue
            if op.type in amp_lists.black_list:
                self._op_fp16_dict[op_id] = False
            elif op.type in amp_lists.white_list:
                self._op_fp16_dict[op_id] = True
            elif op.type in amp_lists.gray_list:
                is_black_op = False
                is_white_op = False
                for in_name in op.input_names:
                    # skip empty input slots
                    if not in_name:
                        continue
                    for in_var_name in op.input(in_name):
                        in_var = self._block.var(in_var_name)
                        # this in_var isn't the output of other op
                        if in_var.op is None:
                            continue
                        elif in_var.op is op:
                            # in-place var: look for the op that produced it
                            # before this one
                            prev_op = find_true_prev_op(ops, op, in_var_name)
                            if prev_op is None:
                                continue
                        else:
                            prev_op = in_var.op
                        prev_is_fp16 = self._is_fp16_op(
                            prev_op.desc.original_id()
                        )
                        if (
                            prev_is_fp16 is False
                            or prev_op.type in amp_lists.black_list
                        ):
                            is_black_op = True
                        elif (
                            prev_is_fp16 is True
                            or prev_op.type in amp_lists.white_list
                        ):
                            is_white_op = True
                if is_black_op:
                    self._op_fp16_dict[op_id] = False
                elif is_white_op:
                    self._op_fp16_dict[op_id] = True
                # otherwise the gray op stays undecided
            else:
                # For numerical safe, we apply fp32 computation on ops that
                # are not determined which list they should stay.
                self._op_fp16_dict[op_id] = False

    def cast_forward_program(self, dist_context):
        """Insert the cast ops required by every decided forward op.

        fp32 ops get their fp16 inputs cast to fp32; fp16 ops get their fp32
        inputs cast to fp16. Undecided ops are left untouched. The walk stops
        at the first backward op.
        """
        ops = self._block.ops
        idx = 0
        while idx < len(ops):
            op = ops[idx]
            num_cast_ops = 0
            if int(op.attr('op_role')) == int(OpRole.Backward):
                break
            is_fp16 = self._is_fp16_op(op.desc.original_id())
            if is_fp16 is False:
                num_cast_ops = self._insert_cast_op_forward(
                    op,
                    idx,
                    core.VarDesc.VarType.FP16,
                    core.VarDesc.VarType.FP32,
                    dist_context,
                )
            elif is_fp16 is True:
                num_cast_ops = self._insert_cast_op_forward(
                    op,
                    idx,
                    core.VarDesc.VarType.FP32,
                    core.VarDesc.VarType.FP16,
                    dist_context,
                )
            # cast ops were inserted in front of ``op``: skip them and ``op``
            idx += num_cast_ops + 1
        self._block._sync_with_cpp()

    def _insert_cast_op_forward(
        self, op, idx, src_dtype, dst_dtype, dist_context
    ):
        """
        Cast the ``src_dtype`` inputs of forward ``op`` to ``dst_dtype``.

        This is only for the forward cast and is modified from
        paddle.fluid.contrib.mixed_precision.

        Each valid input var of dtype ``src_dtype`` is renamed in ``op`` to
        ``<name>.cast_<dst dtype>``. If no var with that name and
        ``dst_dtype`` exists yet, the var is created together with a
        ``cast`` op inserted at ``idx``, i.e. right before ``op``. The cast
        var and cast op take the dist attr that ``op`` has for the original
        input. The renames are recorded in ``self._var_name_dict`` under the
        original id of ``op``. When casting fp32 -> fp16, the valid fp32
        outputs of ``op`` are switched to fp16 as well.

        Args:
            op: forward op whose inputs are cast.
            idx: index of ``op`` in the block; cast ops are inserted there.
            src_dtype: dtype to cast from.
            dst_dtype: dtype to cast to.
            dist_context: distributed context holding the dist attrs.

        Returns:
            int: number of cast ops inserted before ``op``.
        """
        num_cast_ops = 0
        var_name_dict = {}
        for in_name in op.input_names:
            if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
                op, in_name
            ):
                continue
            for in_var_name in op.input(in_name):
                in_var = self._block._find_var_recursive(in_var_name)
                if in_var.type not in _valid_types or in_var.dtype == dst_dtype:
                    continue
                if in_var.dtype != src_dtype:
                    # neither src nor dst dtype: only update the op's
                    # 'in_dtype' attribute, if it has one
                    if op.has_attr('in_dtype'):
                        op._set_attr('in_dtype', dst_dtype)
                    continue

                cast_name = in_var.name + '.cast_' + _dtype_to_str(dst_dtype)
                out_var = self._block.vars.get(cast_name)
                var_name_dict[in_var.name] = cast_name
                consume_op_attr = dist_context.get_op_dist_attr_for_program(op)
                assert consume_op_attr is not None
                in_var_dist_attr = consume_op_attr.get_input_dist_attr(
                    in_var.name
                )
                if out_var is None or out_var.dtype != dst_dtype:
                    # NOTE we make the cast op and var's dist attr as the op that consume the
                    # cast var instead of the op which generates the var
                    assert in_var_dist_attr is not None
                    ref_mesh = in_var_dist_attr.process_mesh
                    ref_mapping = in_var_dist_attr.dims_mapping
                    consume_op_attr.set_input_dist_attr(
                        cast_name, in_var_dist_attr
                    )

                    out_var = self._block.create_var(
                        name=cast_name,
                        dtype=dst_dtype,
                        persistable=False,
                        stop_gradient=in_var.stop_gradient,
                    )
                    set_var_dist_attr(
                        dist_context, out_var, ref_mapping, ref_mesh
                    )

                    op_namescope = "/"
                    if op.has_attr('op_namescope'):
                        op_namescope = op.attr('op_namescope')
                    cast_op = self._block._insert_op_without_sync(
                        idx,
                        type="cast",
                        inputs={"X": in_var},
                        outputs={"Out": out_var},
                        attrs={
                            "in_dtype": in_var.dtype,
                            "out_dtype": out_var.dtype,
                        },
                    )
                    cast_op._set_attr(
                        'op_namescope', op_namescope
                    )  # for recompute
                    naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                        cast_op, ref_mesh, ref_mapping, dist_context
                    )
                    num_cast_ops += 1
                else:
                    # the cast var already exists: only point the consumer's
                    # dist attr at it
                    consume_op_attr.set_input_dist_attr(
                        cast_name, in_var_dist_attr
                    )
                _rename_arg(op, in_var.name, cast_name)
        self._var_name_dict[op.desc.original_id()] = var_name_dict

        if (
            src_dtype == core.VarDesc.VarType.FP32
            and dst_dtype == core.VarDesc.VarType.FP16
        ):
            for out_name in op.output_names:
                if _keep_fp32_output(op, out_name):
                    continue
                for out_var_name in op.output(out_name):
                    out_var = self._block.var(out_var_name)
                    if out_var.type not in _valid_types:
                        continue
                    if out_var.dtype == core.VarDesc.VarType.FP32:
                        out_var.desc.set_dtype(core.VarDesc.VarType.FP16)
                        if op.has_attr('out_dtype'):
                            op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
        return num_cast_ops

    def cast_backward_program(self, params_grads, dist_context):
        """Make the backward ops after the loss op consistent with AMP.

        Grad ops mapped to a decided forward op are fixed up by
        ``_insert_cast_op_backward``. The output dtype of ``sum`` ops is
        aligned with their (identical) input dtype, and ops with op_role 257
        are left as they are. Any other op raises ValueError. Finally, the
        param-grad cast ops are moved by ``_update_backward_cast_ops``.
        """
        self._block._sync_with_cpp()
        ops = self._block.ops
        dist_op_context = dist_context.dist_op_context

        loss_op = get_loss_op(self._block)
        loss_op_index = find_op_index(self._block.desc, loss_op.desc)

        appended_grad_times = 0
        idx = loss_op_index + 1
        while idx < len(ops):
            num_cast_ops = 0
            grad_op = ops[idx]

            # NOTE: the map in `grad_var_to_var` may be changed when the var is casted,
            # which will affect the dist_op to insert allreduce_sum op.
            op_dist_attr = dist_context.get_op_dist_attr_for_program(grad_op)
            if is_backward_op(grad_op) and (
                is_forward_op(ops[idx - 1]) or is_loss_op(ops[idx - 1])
            ):
                # start of a new (non-recompute) backward segment
                if not op_dist_attr.is_recompute:
                    appended_grad_times += 1

            grad_op_orig_id = grad_op.desc.original_id()
            if grad_op_orig_id in dist_op_context.grad_op_id_to_op_id:
                is_fp16 = self._is_fp16_op(grad_op_orig_id)
                if is_fp16 is False:  # fp32
                    num_cast_ops = self._insert_cast_op_backward(
                        grad_op,
                        idx,
                        core.VarDesc.VarType.FP16,
                        core.VarDesc.VarType.FP32,
                        dist_context,
                        appended_grad_times,
                    )
                elif is_fp16 is True:  # fp16
                    num_cast_ops = self._insert_cast_op_backward(
                        grad_op,
                        idx,
                        core.VarDesc.VarType.FP32,
                        core.VarDesc.VarType.FP16,
                        dist_context,
                        appended_grad_times,
                    )
            elif grad_op.type == "sum":
                in_var_names = grad_op.desc.input_arg_names()
                src_dtype = self._block.var(in_var_names[0]).dtype
                for in_var_name in in_var_names:
                    assert src_dtype == self._block.var(in_var_name).dtype
                out_var_name = grad_op.desc.output_arg_names()[0]
                out_var = self._block.var(out_var_name)
                if out_var.dtype != src_dtype:
                    out_var.desc.set_dtype(src_dtype)
            elif int(grad_op.attr('op_role')) == 257:
                pass
            else:
                raise ValueError(
                    "'{}' op is not supported in the complete amp pass.".format(
                        grad_op.type
                    )
                )
            # cast ops were inserted after ``grad_op``: skip them too
            idx += num_cast_ops + 1

        self._block._sync_with_cpp()
        _update_backward_cast_ops(params_grads, dist_context)

    def _insert_cast_op_backward(
        self,
        grad_op,
        idx,
        src_dtype,
        dst_dtype,
        dist_context,
        appended_grad_times,
    ):
        """Only for backward cast.

        Mirrors on ``grad_op`` the casts applied to its forward op:

        * ``src_dtype`` inputs that were cast in the forward pass are renamed
          to the cast var; any other ``src_dtype`` input fails an assertion.
        * Each output's dtype is aligned with that of its forward var. A
          ``src_dtype`` output whose forward var was cast is renamed to
          ``<cast name>@GRAD[@RENAME...]``, and a ``cast`` op converting it
          back is inserted right after ``grad_op``.

        Args:
            grad_op: backward op to fix up.
            idx: index of ``grad_op`` in the block.
            src_dtype: dtype the forward op's inputs were cast from.
            dst_dtype: dtype the forward op's inputs were cast to.
            dist_context: distributed context holding the dist attrs.
            appended_grad_times: index of the current backward segment in
                ``dist_op_context.grad_var_to_var``.

        Returns:
            int: number of cast ops inserted after ``grad_op``.
        """

        def _keep_fp32_grad_input(op, in_name):
            # for layer_norm_grad every input except X and Y@GRAD stays fp32
            if op.type in ['layer_norm_grad']:
                return in_name not in {'X', 'Y@GRAD'}
            return False

        def _keep_fp32_grad_output(op, out_name):
            # for layer_norm_grad every output except X@GRAD stays fp32
            if op.type in ['layer_norm_grad']:
                return out_name != 'X@GRAD'
            return False

        num_cast_ops = 0
        dist_op_context = dist_context.dist_op_context
        fwd_op_id = dist_op_context.grad_op_id_to_op_id[
            grad_op.desc.original_id()
        ]

        for in_name in grad_op.input_names:
            if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_grad_input(
                grad_op, in_name
            ):
                for in_var_name in grad_op.input(in_name):
                    in_var = self._block._find_var_recursive(in_var_name)
                    assert in_var.dtype == core.VarDesc.VarType.FP32
                continue

            for in_var_name in grad_op.input(in_name):
                in_var = self._block._find_var_recursive(in_var_name)
                if in_var.dtype != src_dtype:
                    continue
                consume_op_attr = dist_context.get_op_dist_attr_for_program(
                    grad_op
                )
                if in_var_name in self._var_name_dict[fwd_op_id]:
                    # NOTE: if in_var of consume grad_op has been casted before,
                    # it should be renamed and reset dist_attr.
                    cast_name = self._var_name_dict[fwd_op_id][in_var_name]
                    grad_op.desc._rename_input(in_var_name, cast_name)
                    in_var_dist_attr = consume_op_attr.get_input_dist_attr(
                        in_var_name
                    )
                    consume_op_attr.set_input_dist_attr(
                        cast_name, in_var_dist_attr
                    )
                else:
                    assert (
                        in_var.dtype == dst_dtype
                    ), "op [{}] expect input [{}] to be dtype [{}] BUT got [{}]. {}".format(
                        grad_op.type,
                        in_name,
                        dst_dtype,
                        in_var.dtype,
                        str(grad_op),
                    )

        for out_name in grad_op.output_names:
            if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_grad_output(
                grad_op, out_name
            ):
                for out_var_name in grad_op.output(out_name):
                    out_var = self._block._find_var_recursive(out_var_name)
                    assert out_var.dtype == core.VarDesc.VarType.FP32
                continue

            for out_var_name in grad_op.output(out_name):
                out_var = self._block._find_var_recursive(out_var_name)
                # Forward var name, e.g. "x@GRAD@RENAME@0" -> "x". partition
                # keeps the whole name when there is no "@"; the former
                # ``name[: name.find("@")]`` dropped its last character.
                out_var_name_prefix = out_var_name.partition("@")[0]
                fwd_var = self._block._find_var_recursive(out_var_name_prefix)
                # NOTE: the out_var's dtype of consume grad_op should equal to the fwd_var's dtype
                if fwd_var is not None and out_var.dtype != fwd_var.dtype:
                    out_var.desc.set_dtype(fwd_var.dtype)

                if out_var.dtype != src_dtype:
                    assert out_var.dtype == dst_dtype
                    continue
                if out_var_name_prefix not in self._var_name_dict[fwd_op_id]:
                    continue

                # NOTE: if out_var of consume grad_op has been casted before,
                # it should be renamed and reset dist_attr, then we insert cast op to
                # convert the cast_var to original dtype
                consume_op_attr = dist_context.get_op_dist_attr_for_program(
                    grad_op
                )
                fwd_cast_name = self._var_name_dict[fwd_op_id][
                    out_var_name_prefix
                ]
                suffix = ""
                if "@RENAME" in out_var_name:
                    suffix = out_var_name[out_var_name.find("@RENAME") :]
                cast_name = fwd_cast_name + "@GRAD" + suffix
                cast_var = self._block.vars.get(cast_name)
                if cast_var is None or cast_var.dtype != dst_dtype:
                    grad_op.desc._rename_output(out_var_name, cast_name)
                    out_var_dist_attr = consume_op_attr.get_output_dist_attr(
                        out_var_name
                    )
                    ref_mesh = out_var_dist_attr.process_mesh
                    ref_mapping = out_var_dist_attr.dims_mapping
                    consume_op_attr.set_output_dist_attr(
                        cast_name, out_var_dist_attr
                    )
                    assert ref_mapping is not None
                    cast_var = self._block.create_var(
                        name=cast_name,
                        shape=out_var.shape,
                        dtype=dst_dtype,
                        persistable=False,
                        stop_gradient=out_var.stop_gradient,
                    )
                    set_var_dist_attr(
                        dist_context, cast_var, ref_mapping, ref_mesh
                    )
                    dist_op_context.grad_var_to_var[appended_grad_times][
                        cast_name
                    ] = fwd_cast_name

                    cast_op = self._block._insert_op(
                        idx + 1,
                        type="cast",
                        inputs={"X": cast_var},
                        outputs={"Out": out_var},
                        attrs={
                            "in_dtype": cast_var.dtype,
                            "out_dtype": out_var.dtype,
                            "op_role": OpRole.Backward,
                        },
                    )
                    cast_op._remove_attr("op_role_var")
                    cast_op._remove_attr("op_namescope")
                    cast_op._remove_attr("with_quant_attr")
                    naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                        cast_op, ref_mesh, ref_mapping, dist_context
                    )
                    num_cast_ops += 1

        return num_cast_ops


def _update_backward_cast_ops(params_grads, dist_context):
    """
    Move parameter-gradient cast ops to the end of the backward segment,
    in order to enable fp16 allreduce.

    Some ``(param, grad)`` pairs have an fp32 ``grad`` produced by a ``cast``
    op. For each of them, the cast op is re-appended at the end of the
    global block of the default main program, and the original op is
    removed. Both the new op's dist attr and the cast output's dist attr
    are set from the param's process mesh and dims mapping. A cast op that
    is already the last op is left in place.
    NOTE(review): assumes the block currently ends with the backward
    segment (no optimizer ops yet) -- confirm against the caller.

    Raises:
        ValueError: if ``find_true_post_op`` reports a consumer of the cast
            output, or if the cast op cannot be found in the program.
    """
    # TODO filter optimize ops in future

    main_block = paddle.static.default_main_program().global_block()
    main_block._sync_with_cpp()

    for p, g in params_grads:
        op = g.op
        # A gradient without a producing op cannot come from a cast.
        if op is None:
            continue
        if g.dtype != core.VarDesc.VarType.FP32 or op.type != 'cast':
            continue

        if int(op.attr('op_role')) == int(OpRole.Backward) and op.has_attr(
            'op_role_var'
        ):
            op._remove_attr("op_role_var")

        post_ops = find_true_post_op(main_block.ops, op, g.name)
        if post_ops:
            raise ValueError(
                "The cast op {0}'s output should not be "
                "used by a non-optimize op, however, it "
                "is used by {1}".format(op, post_ops[0])
            )

        if op == main_block.ops[-1]:
            continue

        # add new op in the python and cpp at the same time
        new_op_desc = main_block.desc.append_op()
        new_op_desc.copy_from(op.desc)
        new_op = paddle.fluid.framework.Operator(
            block=main_block,
            desc=new_op_desc,
            type=None,
            inputs=None,
            outputs=None,
            attrs=None,
        )
        main_block.ops.append(new_op)

        # dist attr
        param_dist_attr = dist_context.get_tensor_dist_attr_for_program(p)
        output_dist_attr = dist_context.get_tensor_dist_attr_for_program(
            main_block.var(op.output_arg_names[0])
        )
        assert param_dist_attr is not None
        assert output_dist_attr is not None
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            new_op,
            param_dist_attr.process_mesh,
            param_dist_attr.dims_mapping,
            dist_context,
        )

        output_dist_attr.process_mesh = param_dist_attr.process_mesh
        output_dist_attr.dims_mapping = param_dist_attr.dims_mapping

        # drop the original (now duplicated) cast op
        op_idx = find_op_index(main_block.desc, op.desc)
        if op_idx == -1:
            raise ValueError("The op {0} is not in program".format(op))
        main_block._remove_op(op_idx, sync=False)

    main_block._sync_with_cpp()


def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
    """Append a ``check_finite_and_unscale`` op over all gradients.

    The op is appended to the global block of the default main program. It
    takes every gradient plus ``loss_scaling`` as inputs. It writes the
    gradients back in place and sets a new bool ``found_inf`` variable.
    Its dist attr uses the world process group as the mesh and each
    gradient's own dims mapping.

    Args:
        params_grads: iterable of ``(param, grad)`` pairs.
        loss_scaling: the loss-scaling variable fed as ``Scale``.
        dist_context: distributed context receiving the new dist attrs.

    Returns:
        tuple: ``(grads, found_inf)`` -- the list of gradient vars and the
        bool variable written by the op.
    """
    block = paddle.static.default_main_program().global_block()
    block._sync_with_cpp()

    grads = [grad for _, grad in params_grads]
    check_type(grads, 'x', (tuple, list), 'check_finite_and_unscale')
    for grad in grads:
        check_variable_and_dtype(
            grad,
            "x",
            ['float16', 'float32', 'float64'],
            'check_finite_and_unscale',
        )

    found_inf_name = unique_name.generate_with_ignorable_key(
        ".".join(['find_infinite_scale', 'tmp'])
    )
    found_inf = block.create_var(
        name=found_inf_name,
        shape=[1],
        dtype='bool',
        type=core.VarDesc.VarType.LOD_TENSOR,
        persistable=False,
        stop_gradient=False,
    )
    set_var_dist_attr(dist_context, found_inf, [-1], world_process_group.ranks)

    check_op = block.append_op(
        type='check_finite_and_unscale',
        inputs={'X': grads, 'Scale': loss_scaling},
        outputs={'Out': grads, 'FoundInfinite': found_inf},
        attrs={'op_role': OpRole.Optimize},
    )

    op_dist_attr = OperatorDistributedAttribute()
    op_dist_attr.process_mesh = world_process_group.ranks
    op_dist_attr.impl_idx = 0
    if len(world_process_group.ranks) > 1:
        op_dist_attr.impl_type = "check_finite_and_unscale"
    for grad in grads:
        grad_dist_attr = dist_context.get_tensor_dist_attr_for_program(grad)
        assert grad_dist_attr is not None
        # the op reads and writes each gradient with the same mapping
        op_dist_attr.set_input_dims_mapping(
            grad.name, grad_dist_attr.dims_mapping
        )
        op_dist_attr.set_output_dims_mapping(
            grad.name, grad_dist_attr.dims_mapping
        )
    dist_context.set_op_dist_attr_for_program(check_op, op_dist_attr)
    return grads, found_inf


@register_pass("auto_parallel_amp")
class AMPPass(PassBase):
    def __init__(self):
        """Register the configurable attributes of the AMP pass and reset state.

        Attributes set here may be overridden by the pass configuration before
        the pass is applied.
        """
        super().__init__()
        self.set_attr("loss", None)
        self.set_attr("dist_context", None)
        self.set_attr("custom_white_list", None)
        self.set_attr("custom_black_list", None)
        self.set_attr("custom_black_varnames", None)
        self.set_attr("init_loss_scaling", 32768.0)
        self.set_attr("incr_every_n_steps", 1000)
        self.set_attr("decr_every_n_nan_or_inf", 2)
        self.set_attr("incr_ratio", 2.0)
        self.set_attr("decr_ratio", 0.8)
        self.set_attr("use_dynamic_loss_scaling", False)
        # Forwarded as the `stop_update` attr of the update_loss_scaling op
        # appended by _update_loss_scaling. Registered here so that read always
        # sees an explicit value. Previously it read None, and the op fell back
        # to its own default.
        self.set_attr("stop_update", False)
        self.set_attr("input_data", [])
        self.set_attr("params_grads", [])

        # Effective loss after the pass has run (returned by get_loss).
        self._loss = None
        # Persistable scalar variables created by _init_amp_var.
        self._loss_scaling = None
        self._num_good_steps = None
        self._num_bad_steps = None
        # Scaled loss and its gradient, created by _scale_loss.
        self._scaled_loss = None
        self._scaled_loss_grad = None

    def _check_self(self):
        """Validate the pass attributes before the pass is applied.

        Returns:
            False if any numeric loss-scaling attribute is negative or no
            ``dist_context`` was supplied, True otherwise.
        """
        numeric_attr_names = (
            "init_loss_scaling",
            "incr_every_n_steps",
            "decr_every_n_nan_or_inf",
            "incr_ratio",
            "decr_ratio",
        )
        # `any` short-circuits in tuple order, matching a chain of early returns.
        if any(self.get_attr(name) < 0 for name in numeric_attr_names):
            return False
        return self.get_attr("dist_context") is not None

    def _check_conflict(self, other_pass):
        """Report whether this pass may be combined with ``other_pass``.

        The AMP pass declares no conflicts, so every other pass is accepted.
        """
        return True

659 660
    # NOTE: why can AMPPass override _apply_single_impl instead of
    # _apply_impl? AMP is an optimization pass for a serial program;
    # in the distributed scenario, all ranks should make the same modification.
    def _apply_single_impl(self, main_program, startup_program, context):
        """Rewrite ``main_program`` in place for automatic mixed precision.

        Always casts the forward program according to the white/black lists.
        Training programs (as reported by ``AMPState._build_state``) get more:
        the backward program is also cast and the loss-scaling variables are
        created. The loss is then scaled. When loss scaling is active,
        gradient check/unscale ops are appended. With dynamic loss scaling,
        the loss-scaling update op is appended as well.

        Args:
            main_program: program to rewrite.
            startup_program: installed together with ``main_program`` through
                ``paddle.static.program_guard`` while ops/variables are created.
            context: pass context (not used by this pass).
        """
        self.dist_context = self.get_attr("dist_context")
        params_grads = self.get_attr("params_grads")

        # The custom list attributes default to None (see __init__). Treat a
        # missing list as empty instead of failing on set(None).
        amp_lists = AutoMixedPrecisionLists(
            set(self.get_attr("custom_white_list") or []),
            set(self.get_attr("custom_black_list") or []),
            set(self.get_attr("custom_black_varnames") or []),
        )

        with paddle.static.program_guard(main_program, startup_program):
            amp_state = AMPState(main_program.global_block())
            is_train = amp_state._build_state(amp_lists, self.dist_context)

            amp_state.cast_forward_program(self.dist_context)

        if not is_train:
            return

        with paddle.static.program_guard(main_program, startup_program):
            amp_state.cast_backward_program(params_grads, self.dist_context)
            self._init_amp_var()
            self._scale_loss()

            use_dynamic_loss_scaling = self.get_attr("use_dynamic_loss_scaling")
            if (
                use_dynamic_loss_scaling
                or self.get_attr("init_loss_scaling") != 1.0
            ):
                grads, found_inf = _check_and_update_gradient(
                    params_grads, self._loss_scaling, self.dist_context
                )
                # Nested here so grads/found_inf are always defined when used.
                if use_dynamic_loss_scaling:
                    self._update_loss_scaling(grads, found_inf)

    def _init_amp_var(self):
        """Create the persistable one-element variables used by loss scaling.

        Always creates ``self._loss_scaling`` (float32, initialised from the
        ``init_loss_scaling`` attribute). When ``use_dynamic_loss_scaling`` is
        set, also creates the int32 counters ``self._num_good_steps`` and
        ``self._num_bad_steps``, both initialised to 0. Each variable gets a
        dims_mapping of [-1] over all ranks of ``world_process_group``.
        """

        def _new_scalar(prefix, init_value, dtype):
            # Persistable global var of shape [1] plus its dist attribute.
            scalar = paddle.static.create_global_var(
                name=unique_name.generate(prefix),
                shape=[1],
                value=init_value,
                dtype=dtype,
                persistable=True,
            )
            set_var_dist_attr(
                self.dist_context, scalar, [-1], world_process_group.ranks
            )
            return scalar

        self._loss_scaling = _new_scalar(
            "loss_scaling", self.get_attr("init_loss_scaling"), 'float32'
        )

        if not self.get_attr("use_dynamic_loss_scaling"):
            return

        self._num_good_steps = _new_scalar("num_good_steps", 0, 'int32')
        self._num_bad_steps = _new_scalar("num_bad_steps", 0, 'int32')

    def _scale_loss(self):
        """Insert loss cast / loss scaling ops around the loss op.

        Works on the global block of the default main program and applies
        up to two rewrites:

        1. If the loss is not FP32, the forward side gets a ``cast`` op
           producing an FP32 copy of the loss, right after the loss op. The
           backward side gets a ``cast`` op converting the FP32 gradient back
           to the loss's own dtype.
        2. If dynamic loss scaling is enabled or ``init_loss_scaling`` != 1.0,
           the (possibly cast) loss is multiplied by ``self._loss_scaling``
           with an ``elementwise_mul`` op. A matching ``elementwise_mul_grad``
           op is inserted into the backward section.

        Both rewrites assume the op right after the inserted forward op is the
        ``fill_constant`` that seeds the loss gradient (asserted below).

        Sets ``self._scaled_loss``, which is the scaled loss, or the loss
        itself when no scaling is applied. Sets ``self._loss``, the effective
        (possibly FP32-cast) loss. Requires ``_init_amp_var`` to have created
        ``self._loss_scaling``.
        """

        main_block = paddle.static.default_main_program().global_block()
        main_block._sync_with_cpp()
        OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
        # Role of the fill_constant op that seeds the loss gradient in the
        # generated backward: Backward | Loss (== 257).
        LOSS_GRAD_OP_ROLE = int(OpRole.Backward) | int(OpRole.Loss)

        loss = self.get_attr("loss")
        assert loss is not None
        loss_op = loss.op
        loss_op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
            loss_op
        )

        if loss.dtype != core.VarDesc.VarType.FP32:

            tmp_name = unique_name.generate(loss.name + ".cast_fp32")
            cast_loss = main_block.create_var(
                name=tmp_name, dtype=core.VarDesc.VarType.FP32
            )
            loss_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(
                loss
            )
            ref_mesh = loss_op_dist_attr.process_mesh
            self.dist_context.set_tensor_dist_attr_for_program(
                cast_loss, loss_dist_attr
            )

            # forward: loss -> cast -> cast_loss (FP32), right after loss_op.
            # The cast op takes over loss_op's role; loss_op becomes Forward.
            loss_op_idx = find_op_index(main_block.desc, loss_op.desc)
            cast_op = main_block._insert_op(
                loss_op_idx + 1,
                type='cast',
                inputs={'X': [loss]},
                outputs={'Out': [cast_loss]},
                attrs={
                    "in_dtype": loss.dtype,
                    "out_dtype": core.VarDesc.VarType.FP32,
                    'op_role': loss_op.all_attrs()[OP_ROLE_KEY],
                },
            )

            loss_op._set_attr(
                OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Forward
            )
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                cast_op, ref_mesh, [-1], self.dist_context
            )

            # backward: the gradient seed now writes cast_loss_grad (FP32).
            # A new cast op converts it back into the original loss gradient.
            first_backward_op = main_block.ops[loss_op_idx + 2]
            assert (
                first_backward_op.type == "fill_constant"
                and int(first_backward_op.all_attrs()[OP_ROLE_KEY])
                == LOSS_GRAD_OP_ROLE
            )
            cast_loss_grad = main_block.create_var(
                name=unique_name.generate(tmp_name + "@GRAD"),
                shape=loss.shape,
                dtype=core.VarDesc.VarType.FP32,
                persistable=loss.persistable,
            )
            set_var_dist_attr(self.dist_context, cast_loss_grad, [-1], ref_mesh)

            pre_grad_name = first_backward_op.output_arg_names[0]
            first_backward_op._rename_output(pre_grad_name, cast_loss_grad.name)
            cast_grad_op = main_block._insert_op(
                loss_op_idx + 3,
                type='cast',
                inputs={'X': [cast_loss_grad]},
                outputs={'Out': [pre_grad_name]},
                attrs={
                    "in_dtype": core.VarDesc.VarType.FP32,
                    # Inverse of the forward cast: restore the loss's own
                    # dtype (previously hard-coded to FP16, wrong for e.g.
                    # BF16 or FP64 losses).
                    "out_dtype": loss.dtype,
                    'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
                },
            )
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                cast_grad_op, ref_mesh, [-1], self.dist_context
            )
            # The scaling below operates on the FP32 copy.
            loss_op = cast_op
            loss = cast_loss

        if (
            self.get_attr("use_dynamic_loss_scaling")
            or self.get_attr("init_loss_scaling") != 1.0
        ):

            loss_op_idx = find_op_index(main_block.desc, loss_op.desc)

            # forward: scaled_loss = loss * loss_scaling, right after loss_op.
            ref_mesh = loss_op_dist_attr.process_mesh
            self._scaled_loss = main_block.create_var(
                name=unique_name.generate("scaled_loss"),
                shape=loss.shape,
                dtype=loss.dtype,
                persistable=loss.persistable,
            )
            set_var_dist_attr(
                self.dist_context, self._scaled_loss, [-1], ref_mesh
            )

            elementwise_mul_op = main_block._insert_op(
                loss_op_idx + 1,
                type='elementwise_mul',
                inputs={'X': [loss], 'Y': [self._loss_scaling]},
                outputs={'Out': [self._scaled_loss]},
                attrs={
                    'op_role': loss_op.all_attrs()[OP_ROLE_KEY],
                },
            )
            loss_op._set_attr(
                OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Forward
            )
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                elementwise_mul_op, ref_mesh, [-1], self.dist_context
            )

            # backward: the gradient seed now writes the scaled-loss gradient.
            # elementwise_mul_grad routes it back to the previous loss gradient.
            first_backward_op = main_block.ops[loss_op_idx + 2]
            assert (
                first_backward_op.type == "fill_constant"
                and int(first_backward_op.all_attrs()[OP_ROLE_KEY])
                == LOSS_GRAD_OP_ROLE
            )
            self._scaled_loss_grad = main_block.create_var(
                name=unique_name.generate("scaled_loss") + "@GRAD",
                shape=loss.shape,
                dtype=loss.dtype,
                persistable=loss.persistable,
            )
            set_var_dist_attr(
                self.dist_context, self._scaled_loss_grad, [-1], ref_mesh
            )
            pre_grad_name = first_backward_op.output_arg_names[0]
            first_backward_op._rename_output(
                pre_grad_name, self._scaled_loss_grad.name
            )
            # FIXME(JZ-LIANG) a trick to insert backward op
            main_block._sync_with_cpp()
            elementwise_mul_grad_op_desc = main_block.desc._insert_op(
                loss_op_idx + 3
            )
            elementwise_mul_grad_op_desc.set_type("elementwise_mul_grad")
            elementwise_mul_grad_op_desc.set_input(
                'Out@GRAD', [self._scaled_loss_grad.name]
            )
            elementwise_mul_grad_op_desc.set_input('X', [loss.name])
            elementwise_mul_grad_op_desc.set_input(
                'Y', [self._loss_scaling.name]
            )
            elementwise_mul_grad_op_desc.set_output('X@GRAD', [pre_grad_name])
            elementwise_mul_grad_op_desc.set_output('Y@GRAD', [])
            elementwise_mul_grad_op_desc._set_attr(
                OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Backward
            )
            elementwise_mul_grad_op_desc._set_attr('axis', -1)
            elementwise_mul_grad_op = paddle.fluid.framework.Operator(
                main_block, elementwise_mul_grad_op_desc
            )
            main_block.ops.insert(loss_op_idx + 3, elementwise_mul_grad_op)
            main_block._sync_with_cpp()
            elementwise_mul_grad_op = main_block.ops[loss_op_idx + 3]
            assert elementwise_mul_grad_op.type == "elementwise_mul_grad"
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                elementwise_mul_grad_op, ref_mesh, [-1], self.dist_context
            )

        else:
            self._scaled_loss = loss
        self._loss = loss
        main_block._sync_with_cpp()

    def _update_loss_scaling(self, grads, found_inf):
        """Append an ``update_loss_scaling`` op to the default main program.

        The op reads ``grads``, ``found_inf`` and the loss-scaling state
        (``self._loss_scaling``, ``self._num_good_steps``,
        ``self._num_bad_steps``) and writes the same variables back. It is
        configured by the ``incr_*`` / ``decr_*`` / ``stop_update`` pass
        attributes and tagged with the Optimize role. Its dist attribute spans
        all ranks of ``world_process_group``, and its input/output dims
        mappings mirror each gradient's tensor dist attribute.

        Args:
            grads: list/tuple of float16/float32/float64 gradient variables.
            found_inf: the ``FoundInfinite`` flag variable for ``grads``.
        """
        block = paddle.static.default_main_program().global_block()
        block._sync_with_cpp()

        check_variable_and_dtype(
            self._loss_scaling,
            "prev_loss_scaling",
            ['float32', 'float64'],
            "update_loss_scaling",
        )
        check_type(grads, 'x', (tuple, list), 'update_loss_scaling')
        scaling_dtype = self._loss_scaling.dtype
        for grad in grads:
            check_variable_and_dtype(
                grad,
                "x",
                ['float16', 'float32', 'float64'],
                'update_loss_scaling',
            )
            # FP16 gradients require an FP32 scale; otherwise dtypes must match.
            if grad.dtype != core.VarDesc.VarType.FP16:
                assert (
                    scaling_dtype == grad.dtype
                ), "The dtype of prev_loss_scaling should be equal to the dtype of x."
            else:
                assert (
                    scaling_dtype == core.VarDesc.VarType.FP32
                ), "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."

        update_op = block.append_op(
            type='update_loss_scaling',
            inputs={
                'X': grads,
                'FoundInfinite': found_inf,
                'PrevLossScaling': self._loss_scaling,
                'InGoodSteps': self._num_good_steps,
                'InBadSteps': self._num_bad_steps,
            },
            outputs={
                'Out': grads,
                'LossScaling': self._loss_scaling,
                'OutGoodSteps': self._num_good_steps,
                'OutBadSteps': self._num_bad_steps,
            },
            attrs={
                'incr_every_n_steps': self.get_attr("incr_every_n_steps"),
                'decr_every_n_nan_or_inf': self.get_attr(
                    "decr_every_n_nan_or_inf"
                ),
                'incr_ratio': self.get_attr("incr_ratio"),
                'decr_ratio': self.get_attr("decr_ratio"),
                'stop_update': self.get_attr("stop_update"),
                'op_role': OpRole.Optimize,
            },
        )

        op_dist_attr = OperatorDistributedAttribute()
        op_dist_attr.process_mesh = world_process_group.ranks
        op_dist_attr.impl_idx = 0
        if len(world_process_group.ranks) > 1:
            op_dist_attr.impl_type = "update_loss_scaling"
        for grad in grads:
            grad_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(
                grad
            )
            assert grad_dist_attr is not None
            mapping = grad_dist_attr.dims_mapping
            op_dist_attr.set_input_dims_mapping(grad.name, mapping)
            op_dist_attr.set_output_dims_mapping(grad.name, mapping)
        self.dist_context.set_op_dist_attr_for_program(update_op, op_dist_attr)

        block._sync_with_cpp()

    def get_loss(self):
        """Return the effective loss variable after this pass has run.

        AMP / FP16 rewriting may change the loss variable the network uses
        (``_scale_loss`` records an FP32 cast of a non-FP32 loss). Subsequent
        passes that rely on the loss should use this value. Falls back to the
        ``loss`` attribute when the pass has not recorded a loss yet.
        """
        # Test for None explicitly: the truthiness of a program Variable is
        # not a reliable "has been set" check.
        if self._loss is not None:
            return self._loss
        return self.get_attr("loss")