# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistAttr
from paddle.distributed.auto_parallel.process_group import (
    get_world_process_group,
)
from paddle.distributed.auto_parallel.utils import (
    get_loss_op,
    naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
    set_var_dist_attr,
)
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.fluid.data_feeder import check_type, check_variable_and_dtype
from paddle.framework import core
from paddle.static.amp.fp16_utils import (
    AutoMixedPrecisionLists,
    _dtype_to_str,
    _is_in_black_varnames,
    _keep_fp32_input,
    _keep_fp32_output,
    _rename_arg,
    _valid_types,
    find_op_index,
    find_true_post_op,
    find_true_prev_op,
)
from paddle.utils import unique_name

from ..auto_parallel.process_mesh import ProcessMesh
from ..auto_parallel.utils import is_backward_op, is_forward_op, is_loss_op
from .pass_base import PassBase, register_pass

world_process_group = get_world_process_group()


class AMPState:
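    """
    Bookkeeping for the auto-parallel AMP pass: records which ops run in fp16
    and which stay in fp32, and inserts the matching cast ops into the forward
    and backward programs while keeping their dist_attr consistent.
    """
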
    def __init__(self, block):
        self._block = block
        # op original_id --> True/False. 'True' means that the op runs in fp16 mode.
        self._op_fp16_dict = {}
        self._var_name_dict = {}  # fwd_op_id --> {old_name: cast_name}
        self.is_train = False

    def _is_fp16_op(self, op_id):
        return self._op_fp16_dict.get(op_id, None)

    def _build_state(self, amp_lists, dist_context):
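        """
        Mark every forward op as fp16 or fp32 and let each grad op inherit the
        decision of its forward op. Returns True when the block contains a
        backward pass, i.e. the program is a training program.
        """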
        ops = self._block.ops
        dist_op_context = dist_context.dist_op_context
        for op in ops:
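            # op_role 257 == OpRole.Backward | OpRole.Loss, which is attached
            # to the op right behind the loss, so it only appears in training.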
            if int(op.attr('op_role')) == 257:
                self.is_train = True

            if int(op.attr('op_role')) == int(OpRole.Forward):
                self._mark_black_white_ops(amp_lists)
            elif int(op.attr('op_role')) == int(OpRole.Backward):
                if op.desc.original_id() in dist_op_context.grad_op_id_to_op_id:
                    fwd_op_id = dist_op_context.grad_op_id_to_op_id[
                        op.desc.original_id()
                    ]
                    if self._is_fp16_op(fwd_op_id) is True:
                        self._op_fp16_dict[op.desc.original_id()] = True
                    elif self._is_fp16_op(fwd_op_id) is False:
                        self._op_fp16_dict[op.desc.original_id()] = False
            elif int(op.attr('op_role')) == int(OpRole.Optimize):
                break

        return self.is_train

    def _mark_black_white_ops(self, amp_lists):
        """
        Mark each forward op as fp16 or fp32 according to the white/black/gray
        lists; gray-list ops are resolved from the ops that produce their inputs.
        Modified from paddle.static.amp.
        """
        self._block._sync_with_cpp()
        ops = self._block.ops

        for op in ops:
            if int(op.attr('op_role')) == int(OpRole.Backward):
                break
            if op.type == 'create_py_reader' or op.type == 'read':
                continue
            if amp_lists.black_varnames is not None and _is_in_black_varnames(
                op, amp_lists
            ):
                self._op_fp16_dict[op.desc.original_id()] = False
                continue
            if op.type in amp_lists.black_list:
                self._op_fp16_dict[op.desc.original_id()] = False
            elif op.type in amp_lists.white_list:
                self._op_fp16_dict[op.desc.original_id()] = True
            elif op.type in amp_lists.gray_list:
                is_black_op = False
                is_white_op = False
                for in_name in op.input_names:
                    # if this op has inputs
                    if in_name:
                        for in_var_name in op.input(in_name):
                            in_var = self._block.var(in_var_name)
                            # this in_var isn't the output of any other op
                            if in_var.op is None:
                                continue
                            elif in_var.op is op:
                                prev_op = find_true_prev_op(
                                    ops, op, in_var_name
                                )
                                if prev_op is None:
                                    continue
                            else:
                                prev_op = in_var.op
                            # classify this gray op by the op that produced this input
                            if (
                                self._is_fp16_op(prev_op.desc.original_id())
                                is False
                                or prev_op.type in amp_lists.black_list
                            ):
                                is_black_op = True
                            elif (
                                self._is_fp16_op(prev_op.desc.original_id())
                                is True
                                or prev_op.type in amp_lists.white_list
                            ):
                                is_white_op = True
                if is_black_op:
                    self._op_fp16_dict[op.desc.original_id()] = False
                elif is_white_op:
                    self._op_fp16_dict[op.desc.original_id()] = True
                else:
                    pass
            else:
                # For numerical safety, we apply fp32 computation to ops that
                # are not assigned to any of the lists.
                self._op_fp16_dict[op.desc.original_id()] = False

    def cast_forward_program(self, dist_context):
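        """
        Walk the forward ops (up to the first backward op) and insert the cast
        ops required by the fp16/fp32 decisions recorded in _build_state.
        """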
        ops = self._block.ops
        idx = 0
        while idx < len(ops):
            op = ops[idx]
            num_cast_ops = 0
            if int(op.attr('op_role')) == int(OpRole.Backward):
                break
            if self._is_fp16_op(op.desc.original_id()) is False:
                num_cast_ops = self._insert_cast_op_forward(
                    op,
                    idx,
                    core.VarDesc.VarType.FP16,
                    core.VarDesc.VarType.FP32,
                    dist_context,
                )
            elif self._is_fp16_op(op.desc.original_id()) is True:
                num_cast_ops = self._insert_cast_op_forward(
                    op,
                    idx,
                    core.VarDesc.VarType.FP32,
                    core.VarDesc.VarType.FP16,
                    dist_context,
                )
            else:
                pass
            idx += num_cast_ops + 1
        self._block._sync_with_cpp()

    def _insert_cast_op_forward(
        self, op, idx, src_dtype, dst_dtype, dist_context
    ):
        """
        Only for forward cast; modified from paddle.static.amp.
        """
        num_cast_ops = 0
        var_name_dict = {}
        for in_name in op.input_names:
            if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
                op, in_name
            ):
                continue
            for in_var_name in op.input(in_name):
                in_var = self._block._find_var_recursive(in_var_name)
                if in_var.type not in _valid_types or in_var.dtype == dst_dtype:
                    continue
                if in_var.dtype == src_dtype:
                    cast_name = (
                        in_var.name + '.cast_' + _dtype_to_str(dst_dtype)
                    )
                    out_var = self._block.vars.get(cast_name)
                    var_name_dict[in_var.name] = cast_name
                    consume_op_attr = dist_context.get_op_dist_attr_for_program(
                        op
                    )
                    assert consume_op_attr is not None
                    if out_var is None or out_var.dtype != dst_dtype:
                        # NOTE: the cast op and the cast var take their dist_attr
                        # from the op that consumes the cast var, not from the op
                        # that generates it
                        in_var_dist_attr = consume_op_attr.get_input_dist_attr(
                            in_var.name
                        )
                        assert in_var_dist_attr is not None
                        ref_mesh = in_var_dist_attr.process_mesh
                        ref_mapping = in_var_dist_attr.dims_mapping
                        consume_op_attr.set_input_dist_attr(
                            cast_name, in_var_dist_attr
                        )

                        out_var = self._block.create_var(
                            name=cast_name,
                            dtype=dst_dtype,
                            persistable=False,
                            stop_gradient=in_var.stop_gradient,
                        )
                        set_var_dist_attr(
                            dist_context, out_var, ref_mapping, ref_mesh
                        )

                        op_namescope = "/"
                        if op.has_attr('op_namescope'):
                            op_namescope = op.attr('op_namescope')
                        cast_op = self._block._insert_op_without_sync(
                            idx,
                            type="cast",
                            inputs={"X": in_var},
                            outputs={"Out": out_var},
                            attrs={
                                "in_dtype": in_var.dtype,
                                "out_dtype": out_var.dtype,
                            },
                        )
                        cast_op._set_attr(
                            'op_namescope', op_namescope
                        )  # for recompute
                        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                            cast_op, ref_mesh, ref_mapping, dist_context
                        )
                        num_cast_ops += 1
                    else:
                        in_var_dist_attr = consume_op_attr.get_input_dist_attr(
                            in_var.name
                        )
                        consume_op_attr.set_input_dist_attr(
                            cast_name, in_var_dist_attr
                        )
                    _rename_arg(op, in_var.name, cast_name)
                else:
                    if op.has_attr('in_dtype'):
                        op._set_attr('in_dtype', dst_dtype)
        self._var_name_dict[op.desc.original_id()] = var_name_dict

        if (
            src_dtype == core.VarDesc.VarType.FP32
            and dst_dtype == core.VarDesc.VarType.FP16
        ):
            for out_name in op.output_names:
                if _keep_fp32_output(op, out_name):
                    continue
                for out_var_name in op.output(out_name):
                    out_var = self._block.var(out_var_name)
                    if out_var.type not in _valid_types:
                        continue
                    if out_var.dtype == core.VarDesc.VarType.FP32:
                        out_var.desc.set_dtype(core.VarDesc.VarType.FP16)
                        if op.has_attr('out_dtype'):
                            op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
        return num_cast_ops

    def cast_backward_program(self, params_grads, dist_context):
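        """
        Insert cast ops for the backward ops, mirroring the casts added to the
        forward program, then move the parameter-gradient casts to the end of
        the backward segment.
        """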
        self._block._sync_with_cpp()
        ops = self._block.ops

        loss_op = get_loss_op(self._block)
        loss_op_index = find_op_index(self._block.desc, loss_op.desc)

        appended_grad_times = 0
        idx = loss_op_index + 1
        while idx < len(ops):
            num_cast_ops = 0
            grad_op = ops[idx]

            # NOTE: the map in `grad_var_to_var` may be changed when the var is casted,
            # which will affect the dist_op to insert allreduce_sum op.
            op_dist_attr = dist_context.get_op_dist_attr_for_program(grad_op)
            if is_backward_op(grad_op) and (
                is_forward_op(ops[idx - 1]) or is_loss_op(ops[idx - 1])
            ):
                if not op_dist_attr.is_recompute:
                    appended_grad_times += 1

            grad_op_orig_id = grad_op.desc.original_id()
            dist_op_context = dist_context.dist_op_context
            if grad_op_orig_id in dist_op_context.grad_op_id_to_op_id:
                if self._is_fp16_op(grad_op_orig_id) is False:  # fp32
                    num_cast_ops = self._insert_cast_op_backward(
                        grad_op,
                        idx,
                        core.VarDesc.VarType.FP16,
                        core.VarDesc.VarType.FP32,
                        dist_context,
                        appended_grad_times,
                    )
                elif self._is_fp16_op(grad_op_orig_id) is True:  # fp16
                    num_cast_ops = self._insert_cast_op_backward(
                        grad_op,
                        idx,
                        core.VarDesc.VarType.FP32,
                        core.VarDesc.VarType.FP16,
                        dist_context,
                        appended_grad_times,
                    )
            elif grad_op.type == "sum":
                in_var_name = grad_op.desc.input_arg_names()[0]
                src_dtype = self._block.var(in_var_name).dtype
                for in_var_name in grad_op.desc.input_arg_names():
                    assert src_dtype == self._block.var(in_var_name).dtype
                out_var_name = grad_op.desc.output_arg_names()[0]
                out_var = self._block.var(out_var_name)
                if out_var.dtype != src_dtype:
                    out_var.desc.set_dtype(src_dtype)
            elif int(grad_op.attr('op_role')) == 257:
                pass
            else:
                raise ValueError(
                    "'{}' op is not supported in the complete amp pass.".format(
                        grad_op.type
                    )
                )
            idx += num_cast_ops + 1

        self._block._sync_with_cpp()
        _update_backward_cast_ops(params_grads, dist_context)

    def _insert_cast_op_backward(
        self,
        grad_op,
        idx,
        src_dtype,
        dst_dtype,
        dist_context,
        appended_grad_times,
    ):
        """only for backward cast"""

        def _keep_fp32_input(op, in_name):
            op_type = op.type
            if op_type in ['layer_norm_grad']:
                return in_name not in {'X', 'Y@GRAD'}
            return False

        def _keep_fp32_output(op, out_name):
            op_type = op.type
            if op_type in ['layer_norm_grad']:
                return out_name != 'X@GRAD'
            return False

        num_cast_ops = 0
        original_id = grad_op.desc.original_id()
        dist_op_context = dist_context.dist_op_context
        fwd_op_id = dist_op_context.grad_op_id_to_op_id[original_id]

        for in_name in grad_op.input_names:
            if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
                grad_op, in_name
            ):
                for in_var_name in grad_op.input(in_name):
                    in_var = self._block._find_var_recursive(in_var_name)
                    assert in_var.dtype == core.VarDesc.VarType.FP32
                continue

            for in_var_name in grad_op.input(in_name):
                in_var = self._block._find_var_recursive(in_var_name)
                if in_var.dtype == src_dtype:
                    consume_op_attr = dist_context.get_op_dist_attr_for_program(
                        grad_op
                    )
                    if in_var_name in self._var_name_dict[fwd_op_id]:
                        # NOTE: if this input of the grad op was cast in the
                        # forward program, rename it here and reset its dist_attr.
                        cast_name = self._var_name_dict[fwd_op_id][in_var_name]
                        grad_op.desc._rename_input(in_var_name, cast_name)
                        in_var_dist_attr = consume_op_attr.get_input_dist_attr(
                            in_var_name
                        )
                        consume_op_attr.set_input_dist_attr(
                            cast_name, in_var_dist_attr
                        )
                    else:
                        assert (
                            in_var.dtype == dst_dtype
                        ), "op [{}] expect input [{}] to be dtype [{}] BUT got [{}]. {}".format(
                            grad_op.type,
                            in_name,
                            dst_dtype,
                            in_var.dtype,
                            str(grad_op),
                        )

        for out_name in grad_op.output_names:
            if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_output(
                grad_op, out_name
            ):
                for out_var_name in grad_op.output(out_name):
                    out_var = self._block._find_var_recursive(out_var_name)
                    assert out_var.dtype == core.VarDesc.VarType.FP32
                continue

            for out_var_name in grad_op.output(out_name):
                out_var = self._block._find_var_recursive(out_var_name)
                out_var_name_prefix = out_var_name[: out_var_name.find("@")]
                fwd_var = self._block._find_var_recursive(out_var_name_prefix)
                # NOTE: the dtype of the grad op's out_var should equal the
                # dtype of the corresponding fwd_var
                if out_var.dtype != fwd_var.dtype:
                    out_var.desc.set_dtype(fwd_var.dtype)

                if out_var.dtype == src_dtype:
                    if out_var_name_prefix in self._var_name_dict[fwd_op_id]:
                        # NOTE: if this output of the grad op was cast in the
                        # forward program, rename it and reset its dist_attr, then
                        # insert a cast op to convert the cast var back to the
                        # original dtype
                        consume_op_attr = (
                            dist_context.get_op_dist_attr_for_program(grad_op)
                        )
                        fwd_cast_name = self._var_name_dict[fwd_op_id][
                            out_var_name_prefix
                        ]
                        suffix = ""
                        if "@RENAME" in out_var_name:
                            suffix = out_var_name[
                                out_var_name.find("@RENAME") :
                            ]
                        cast_name = fwd_cast_name + "@GRAD" + suffix
                        cast_var = self._block.vars.get(cast_name)
                        if cast_var is None or cast_var.dtype != dst_dtype:
                            grad_op.desc._rename_output(out_var_name, cast_name)
                            out_var_dist_attr = (
                                consume_op_attr.get_output_dist_attr(
                                    out_var_name
                                )
                            )
                            ref_mesh = out_var_dist_attr.process_mesh
                            ref_mapping = out_var_dist_attr.dims_mapping
                            consume_op_attr.set_output_dist_attr(
                                cast_name, out_var_dist_attr
                            )
                            assert ref_mapping is not None
                            cast_var = self._block.create_var(
                                name=cast_name,
                                shape=out_var.shape,
                                dtype=dst_dtype,
                                persistable=False,
                                stop_gradient=out_var.stop_gradient,
                            )
                            set_var_dist_attr(
                                dist_context, cast_var, ref_mapping, ref_mesh
                            )
                            dist_op_context.grad_var_to_var[
                                appended_grad_times
                            ][cast_name] = fwd_cast_name

                            cast_op = self._block._insert_op(
                                idx + 1,
                                type="cast",
                                inputs={"X": cast_var},
                                outputs={"Out": out_var},
                                attrs={
                                    "in_dtype": cast_var.dtype,
                                    "out_dtype": out_var.dtype,
                                    "op_role": OpRole.Backward,
                                },
                            )
                            cast_op._remove_attr("op_role_var")
                            cast_op._remove_attr("op_namescope")
                            cast_op._remove_attr("with_quant_attr")
                            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                                cast_op, ref_mesh, ref_mapping, dist_context
                            )
                            num_cast_ops += 1
                else:
                    assert out_var.dtype == dst_dtype

        return num_cast_ops


def _update_backward_cast_ops(params_grads, dist_context):
    """
    Move the parameter-gradient cast ops to the end of the backward segment
    in order to enable fp16 allreduce.
    """
    # TODO filter optimize ops in future

    main_block = paddle.static.default_main_program().global_block()
    main_block._sync_with_cpp()

    for p, g in params_grads:
        op = g.op
        if g.dtype == core.VarDesc.VarType.FP32 and op.type == 'cast':
            if int(op.attr('op_role')) == int(OpRole.Backward) and op.has_attr(
                'op_role_var'
            ):
                op._remove_attr("op_role_var")

            post_ops = find_true_post_op(main_block.ops, op, g.name)
            if post_ops:
                raise ValueError(
                    "The cast op {0}'s output should not be "
                    "used by a non-optimize op, however, it "
                    "is used by {1}".format(op, post_ops[0])
                )

            if op == main_block.ops[-1]:
                continue

            # add new op in the python and cpp at the same time
            new_op_desc = main_block.desc.append_op()
            new_op_desc.copy_from(op.desc)
            new_op = paddle.static.Operator(
                block=main_block,
                desc=new_op_desc,
                type=None,
                inputs=None,
                outputs=None,
                attrs=None,
            )
            main_block.ops.append(new_op)

            # dist attr
            param_dist_attr = dist_context.get_tensor_dist_attr_for_program(p)
            output_dist_attr = dist_context.get_tensor_dist_attr_for_program(
                main_block.var(op.output_arg_names[0])
            )
            assert param_dist_attr is not None
            assert output_dist_attr is not None
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                new_op,
                param_dist_attr.process_mesh,
                param_dist_attr.dims_mapping,
                dist_context,
            )

            output_dist_attr.process_mesh = param_dist_attr.process_mesh
            output_dist_attr.dims_mapping = param_dist_attr.dims_mapping

            op_idx = find_op_index(main_block.desc, op.desc)
            if op_idx == -1:
                raise ValueError("The op {0} is not in program".format(op))
            main_block._remove_op(op_idx, sync=False)

    main_block._sync_with_cpp()


def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
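    """
    Append a check_finite_and_unscale op over all gradients and return the
    unscaled gradients together with the found_inf flag.
    """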

    main_block = paddle.static.default_main_program().global_block()
    main_block._sync_with_cpp()

    grads = [g for _, g in params_grads]
    check_type(grads, 'x', (tuple, list), 'check_finite_and_unscale')
    for e in grads:
        check_variable_and_dtype(
            e,
            "x",
            ['float16', 'float32', 'float64'],
            'check_finite_and_unscale',
        )

    found_inf = main_block.create_var(
        name=unique_name.generate_with_ignorable_key(
            ".".join(['find_infinite_scale', 'tmp'])
        ),
        shape=[1],
        dtype='bool',
        type=core.VarDesc.VarType.LOD_TENSOR,
        persistable=False,
        stop_gradient=False,
    )
    set_var_dist_attr(dist_context, found_inf, [-1], world_process_group.ranks)

    inputs = {'X': grads, 'Scale': loss_scaling}
    outputs = {'Out': grads, 'FoundInfinite': found_inf}
    attrs = {'op_role': OpRole.Optimize}
    new_op = main_block.append_op(
        type='check_finite_and_unscale',
        inputs=inputs,
        outputs=outputs,
        attrs=attrs,
    )

    # Constructing dist attr from op_desc can
    # give all inputs and outputs default dist attrs
    new_op_dist_attr = OperatorDistAttr(new_op.desc)
    new_op_dist_attr.process_mesh = ProcessMesh(world_process_group.ranks)
    new_op_dist_attr.impl_idx = 0
    if len(world_process_group.ranks) > 1:
        new_op_dist_attr.impl_type = "check_finite_and_unscale"
    for g in grads:
        g_dist_attr = dist_context.get_tensor_dist_attr_for_program(g)
        assert g_dist_attr is not None
        new_op_dist_attr.set_input_dims_mapping(
            g.name, g_dist_attr.dims_mapping
        )
        new_op_dist_attr.set_output_dims_mapping(
            g.name, g_dist_attr.dims_mapping
        )
    dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
    return grads, found_inf


@register_pass("auto_parallel_amp")
class AMPPass(PassBase):
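    """
    Auto-parallel AMP pass: casts the forward and backward programs to mixed
    precision, scales the loss, and optionally maintains dynamic loss scaling.

    Rough usage sketch (hedged; the attribute names follow the set_attr calls
    in __init__ below, and new_pass/pass_context come from this passes package):

        amp_pass = new_pass(
            "auto_parallel_amp",
            {
                "dtype": "float16",
                "loss": loss,
                "dist_context": dist_context,
                "params_grads": params_grads,
            },
        )
        amp_pass.apply([main_program], [startup_program], pass_context)
    """
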
    def __init__(self):
        super().__init__()
        self.set_attr("loss", None)
        self.set_attr("dist_context", None)
        self.set_attr("custom_white_list", None)
        self.set_attr("custom_black_list", None)
        self.set_attr("custom_black_varnames", None)
        self.set_attr("init_loss_scaling", 32768.0)
        self.set_attr("incr_every_n_steps", 1000)
        self.set_attr("decr_every_n_nan_or_inf", 2)
        self.set_attr("incr_ratio", 2.0)
        self.set_attr("decr_ratio", 0.8)
        self.set_attr("use_dynamic_loss_scaling", False)
        self.set_attr("input_data", [])
        self.set_attr("params_grads", [])
        self.set_attr("dtype", "")  # fp16/bf16
        self._loss = None
        self._loss_scaling = None
        self._num_good_steps = None
        self._num_bad_steps = None

    def _check_self(self):
        if self.get_attr("dtype") not in ["float16", "bfloat16"]:
            return False
        if self.get_attr("init_loss_scaling") < 0:
            return False
        if self.get_attr("incr_every_n_steps") < 0:
            return False
        if self.get_attr("decr_every_n_nan_or_inf") < 0:
            return False
        if self.get_attr("incr_ratio") < 0:
            return False
        if self.get_attr("decr_ratio") < 0:
            return False
        if self.get_attr("dist_context") is None:
            return False
        return True

    def _check_conflict(self, other_pass):

        return True

    # NOTE: this pass overrides apply_single_impl instead of apply_impl because
    # AMP is an optimization pass for the serial program, and in the distributed
    # scenario all ranks should apply the same modification.
    def _apply_single_impl(self, main_program, startup_program, context):
        self.dist_context = self.get_attr("dist_context")
        params_grads = self.get_attr("params_grads")

        amp_lists = AutoMixedPrecisionLists(
            set(self.get_attr("custom_white_list")),
            set(self.get_attr("custom_black_list")),
            set(self.get_attr("custom_black_varnames")),
        )

        with paddle.static.program_guard(main_program, startup_program):
            amp_state = AMPState(main_program.global_block())
            is_train = amp_state._build_state(amp_lists, self.dist_context)

            amp_state.cast_forward_program(self.dist_context)

        if is_train:
            with paddle.static.program_guard(main_program, startup_program):
                amp_state.cast_backward_program(params_grads, self.dist_context)
                self._init_amp_var()
                self._scale_loss()

                if (
                    self.get_attr("use_dynamic_loss_scaling")
                    or self.get_attr("init_loss_scaling") != 1.0
                ):
                    grads, found_inf = _check_and_update_gradient(
                        params_grads, self._loss_scaling, self.dist_context
                    )

                if self.get_attr("use_dynamic_loss_scaling"):
                    self._update_loss_scaling(grads, found_inf)

    def _init_amp_var(self):
        self._loss_scaling = paddle.static.create_global_var(
            name=unique_name.generate("loss_scaling"),
            shape=[1],
            value=self.get_attr("init_loss_scaling"),
            dtype='float32',
            persistable=True,
        )
        set_var_dist_attr(
            self.dist_context,
            self._loss_scaling,
            [-1],
            world_process_group.ranks,
        )

        if self.get_attr("use_dynamic_loss_scaling"):
            self._num_good_steps = paddle.static.create_global_var(
                name=unique_name.generate("num_good_steps"),
                shape=[1],
                value=0,
                dtype='int32',
                persistable=True,
            )
            set_var_dist_attr(
                self.dist_context,
                self._num_good_steps,
                [-1],
                world_process_group.ranks,
            )

            self._num_bad_steps = paddle.static.create_global_var(
                name=unique_name.generate("num_bad_steps"),
                shape=[1],
                value=0,
                dtype='int32',
                persistable=True,
            )
            set_var_dist_attr(
                self.dist_context,
                self._num_bad_steps,
                [-1],
                world_process_group.ranks,
            )

    def _scale_loss(self):
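        """
        Cast the loss to fp32 if necessary and multiply it by the loss-scaling
        factor, patching the first backward (fill_constant) op so the gradient
        flows through the scaled loss.
        """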

        main_block = paddle.static.default_main_program().global_block()
        main_block._sync_with_cpp()
        OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()

        loss = self.get_attr("loss")
        assert loss is not None
        loss_op = loss.op
        loss_op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
            loss_op
        )

        if loss.dtype != core.VarDesc.VarType.FP32:

            tmp_name = unique_name.generate(loss.name + ".cast_fp32")
            cast_loss = main_block.create_var(
                name=tmp_name, dtype=core.VarDesc.VarType.FP32
            )
            loss_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(
                loss
            )
            ref_mesh = loss_op_dist_attr.process_mesh
            self.dist_context.set_tensor_dist_attr_for_program(
                cast_loss, loss_dist_attr
            )

            # forward
            loss_op_idx = find_op_index(main_block.desc, loss_op.desc)
            cast_op = main_block._insert_op(
                loss_op_idx + 1,
                type='cast',
                inputs={'X': [loss]},
                outputs={'Out': [cast_loss]},
                attrs={
                    "in_dtype": loss.dtype,
                    "out_dtype": core.VarDesc.VarType.FP32,
                    'op_role': loss_op.all_attrs()[OP_ROLE_KEY],
                },
            )

            loss_op._set_attr(
                OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Forward
            )
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                cast_op, ref_mesh, [-1], self.dist_context
            )

            # backward
            first_backward_op = main_block.ops[loss_op_idx + 2]
            assert (
                first_backward_op.type == "fill_constant"
                and int(first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
            )
            cast_loss_grad = main_block.create_var(
                name=unique_name.generate(tmp_name + "@GRAD"),
                shape=loss.shape,
                dtype=core.VarDesc.VarType.FP32,
                persistable=loss.persistable,
            )
            set_var_dist_attr(self.dist_context, cast_loss_grad, [-1], ref_mesh)

            pre_grad_name = first_backward_op.output_arg_names[0]
            first_backward_op._rename_output(pre_grad_name, cast_loss_grad.name)
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                first_backward_op, ref_mesh, [-1], self.dist_context
            )
            cast_grad_op = main_block._insert_op(
                loss_op_idx + 3,
                type='cast',
                inputs={'X': [cast_loss_grad]},
                outputs={'Out': [pre_grad_name]},
                attrs={
                    "in_dtype": core.VarDesc.VarType.FP32,
                    "out_dtype": core.VarDesc.VarType.FP16,
                    'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
                },
            )
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                cast_grad_op, ref_mesh, [-1], self.dist_context
            )
            loss_op = cast_op
            loss = cast_loss

        if (
            self.get_attr("use_dynamic_loss_scaling")
            or self.get_attr("init_loss_scaling") != 1.0
        ):

            loss_op_idx = find_op_index(main_block.desc, loss_op.desc)

            # forward
            ref_mesh = loss_op_dist_attr.process_mesh
            self._scaled_loss = main_block.create_var(
                name=unique_name.generate("scaled_loss"),
                shape=loss.shape,
                dtype=loss.dtype,
                persistable=loss.persistable,
            )
            set_var_dist_attr(
                self.dist_context, self._scaled_loss, [-1], ref_mesh
            )

            elementwise_mul_op = main_block._insert_op(
                loss_op_idx + 1,
                type='elementwise_mul',
                inputs={'X': [loss], 'Y': [self._loss_scaling]},
                outputs={'Out': [self._scaled_loss]},
                attrs={
                    'op_role': loss_op.all_attrs()[OP_ROLE_KEY],
                },
            )
            loss_op._set_attr(
                OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Forward
            )
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                elementwise_mul_op, ref_mesh, [-1], self.dist_context
            )

            # backward
            first_backward_op = main_block.ops[loss_op_idx + 2]
            assert (
                first_backward_op.type == "fill_constant"
                and int(first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
            )
            self._scaled_loss_grad = main_block.create_var(
                name=unique_name.generate("scaled_loss") + "@GRAD",
                shape=loss.shape,
                dtype=loss.dtype,
                persistable=loss.persistable,
            )
            set_var_dist_attr(
                self.dist_context, self._scaled_loss_grad, [-1], ref_mesh
            )
            pre_grad_name = first_backward_op.output_arg_names[0]
            first_backward_op._rename_output(
                pre_grad_name, self._scaled_loss_grad.name
            )
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                first_backward_op, ref_mesh, [-1], self.dist_context
            )
            self._scaled_loss_grad.op = first_backward_op
            # FIXME(JZ-LIANG) a trick to insert backward op
            main_block._sync_with_cpp()
            elementwise_mul_grad_op_desc = main_block.desc._insert_op(
                loss_op_idx + 3
            )
            elementwise_mul_grad_op_desc.set_type("elementwise_mul_grad")
            elementwise_mul_grad_op_desc.set_input(
                'Out@GRAD', [self._scaled_loss_grad.name]
            )
            elementwise_mul_grad_op_desc.set_input('X', [loss.name])
            elementwise_mul_grad_op_desc.set_input(
                'Y', [self._loss_scaling.name]
            )
            elementwise_mul_grad_op_desc.set_output('X@GRAD', [pre_grad_name])
            elementwise_mul_grad_op_desc.set_output('Y@GRAD', [])
            elementwise_mul_grad_op_desc._set_attr(
                OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Backward
            )
            elementwise_mul_grad_op_desc._set_attr('axis', -1)
            elementwise_mul_grad_op = paddle.static.Operator(
                main_block, elementwise_mul_grad_op_desc
            )
            main_block.ops.insert(loss_op_idx + 3, elementwise_mul_grad_op)
            main_block._sync_with_cpp()
            elementwise_mul_grad_op = main_block.ops[loss_op_idx + 3]
            assert elementwise_mul_grad_op.type == "elementwise_mul_grad"
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                elementwise_mul_grad_op, ref_mesh, [-1], self.dist_context
            )

        else:
            self._scaled_loss = loss
        self._loss = loss
        main_block._sync_with_cpp()

    def _update_loss_scaling(self, grads, found_inf):
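        """
        Append an update_loss_scaling op that adjusts the loss-scaling value
        according to whether inf/nan was found in the gradients.
        """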

        main_block = paddle.static.default_main_program().global_block()
        main_block._sync_with_cpp()

        check_variable_and_dtype(
            self._loss_scaling,
            "prev_loss_scaling",
            ['float32', 'float64'],
            "update_loss_scaling",
        )
        check_type(grads, 'x', (tuple, list), 'update_loss_scaling')
        for e in grads:
            check_variable_and_dtype(
                e, "x", ['float16', 'float32', 'float64'], 'update_loss_scaling'
            )
            if e.dtype == core.VarDesc.VarType.FP16:
                assert (
                    self._loss_scaling.dtype == core.VarDesc.VarType.FP32
                ), "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
            else:
                assert (
                    self._loss_scaling.dtype == e.dtype
                ), "The dtype of prev_loss_scaling should be equal to the dtype of x."

        inputs = {
            'X': grads,
            'FoundInfinite': found_inf,
            'PrevLossScaling': self._loss_scaling,
            'InGoodSteps': self._num_good_steps,
            'InBadSteps': self._num_bad_steps,
        }

        outputs = {
            'Out': grads,
            'LossScaling': self._loss_scaling,
            'OutGoodSteps': self._num_good_steps,
            'OutBadSteps': self._num_bad_steps,
        }

        attrs = {
            'incr_every_n_steps': self.get_attr("incr_every_n_steps"),
            'decr_every_n_nan_or_inf': self.get_attr("decr_every_n_nan_or_inf"),
            'incr_ratio': self.get_attr("incr_ratio"),
            'decr_ratio': self.get_attr("decr_ratio"),
            'stop_update': self.get_attr("stop_update"),
            'op_role': OpRole.Optimize,
        }

        new_op = main_block.append_op(
            type='update_loss_scaling',
            inputs=inputs,
            outputs=outputs,
            attrs=attrs,
        )

        # Constructing dist attr from op_desc can
        # give all inputs and outputs default dist attrs
        new_op_dist_attr = OperatorDistAttr(new_op.desc)
        new_op_dist_attr.process_mesh = ProcessMesh(world_process_group.ranks)
        new_op_dist_attr.impl_idx = 0
        if len(world_process_group.ranks) > 1:
            new_op_dist_attr.impl_type = "update_loss_scaling"
        for g in grads:
            g_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(g)
            assert g_dist_attr is not None
            new_op_dist_attr.set_input_dims_mapping(
                g.name, g_dist_attr.dims_mapping
            )
            new_op_dist_attr.set_output_dims_mapping(
                g.name, g_dist_attr.dims_mapping
            )
        self.dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)

        main_block._sync_with_cpp()

    def get_loss(self):
        # The AMP / fp16 pass might change the effective loss variable of the
        # network and therefore affect subsequent passes that rely on the loss,
        # so return the effective loss after the AMP / fp16 pass.

        if self._loss:
            return self._loss
        else:
            return self.get_attr("loss")