# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.distributed.auto_parallel.dist_attribute import (
    OperatorDistributedAttribute,
)
from paddle.distributed.auto_parallel.process_group import (
    get_world_process_group,
)
from paddle.distributed.auto_parallel.utils import (
    get_loss_op,
    naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
    set_var_dist_attr,
)
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.fluid import unique_name
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
    AutoMixedPrecisionLists,
    _dtype_to_str,
    _is_in_black_varnames,
    _keep_fp32_input,
    _keep_fp32_output,
    _rename_arg,
    _valid_types,
    find_op_index,
    find_true_post_op,
    find_true_prev_op,
)
from paddle.fluid.data_feeder import check_type, check_variable_and_dtype
from paddle.framework import core

from ..auto_parallel.utils import is_backward_op, is_forward_op, is_loss_op
from .pass_base import PassBase, register_pass

# Presumably the process group of all ranks. In this file it is used as the
# process mesh of the check_finite_and_unscale op and its found_inf flag var.
world_process_group = get_world_process_group()


class AMPState:
    """
    Per-block state of the auto-parallel AMP pass.

    Records whether each op runs in fp16 (True) or fp32 (False). Records
    which forward inputs were replaced by cast vars. Inserts the ``cast``
    ops that those decisions require.

    * ``_build_state`` classifies the forward ops. It copies each decision
      to the grad ops generated from them.
    * ``cast_forward_program`` inserts the forward casts and records the
      cast var names in ``self._var_name_dict``.
    * ``cast_backward_program`` reuses the recorded names, so it is
      expected to run after ``cast_forward_program``.
    """

    # op_role of the loss-gradient op:
    # int(OpRole.Backward) | int(OpRole.Loss) == 0x0001 | 0x0100 == 257.
    _LOSS_GRAD_OP_ROLE = 257

    def __init__(self, block):
        self._block = block
        # op original_id --> True/False. 'True' means that the op is in fp16
        # mode. Ops without an entry are left untouched by the cast methods.
        self._op_fp16_dict = {}
        # fwd op original_id --> {old input name: cast input name}
        self._var_name_dict = {}
        self.is_train = False

    def _is_fp16_op(self, op_id):
        """
        Return True (fp16), False (fp32) or None (undecided) for ``op_id``.
        """
        return self._op_fp16_dict.get(op_id, None)

    def _build_state(self, amp_lists, dist_context):
        """
        Classify the ops of the block and report whether it is a training block.

        Forward ops are classified by ``_mark_black_white_ops``. A backward op
        maps to a forward op through
        ``dist_context.dist_op_context.grad_op_id_to_op_id``; when that forward
        op has a decision, the backward op inherits it. Scanning stops at the
        first optimizer op.

        Returns:
            bool: ``self.is_train``. It is True once an op with the
            loss-gradient role has been seen.
        """
        ops = self._block.ops
        dist_op_context = dist_context.dist_op_context
        forward_marked = False
        for op in ops:
            op_role = int(op.attr('op_role'))
            if op_role == self._LOSS_GRAD_OP_ROLE:
                self.is_train = True

            if op_role == int(OpRole.Forward):
                # _mark_black_white_ops already scans the whole forward
                # segment (and syncs the block). Running it once is enough;
                # re-running it for every forward op made this O(n^2).
                if not forward_marked:
                    self._mark_black_white_ops(amp_lists)
                    forward_marked = True
            elif op_role == int(OpRole.Backward):
                grad_op_id = op.desc.original_id()
                if grad_op_id in dist_op_context.grad_op_id_to_op_id:
                    fwd_op_id = dist_op_context.grad_op_id_to_op_id[grad_op_id]
                    fwd_is_fp16 = self._is_fp16_op(fwd_op_id)
                    # undecided forward ops leave their grad op undecided too
                    if fwd_is_fp16 is not None:
                        self._op_fp16_dict[grad_op_id] = fwd_is_fp16
            elif op_role == int(OpRole.Optimize):
                break

        return self.is_train

    def _mark_black_white_ops(self, amp_lists):
        """
        Decide fp16 (True) / fp32 (False) for the ops before the first
        backward op.

        This function is modified from paddle.fluid.contrib.mixed_precision.

        * Ops matched by ``_is_in_black_varnames`` run in fp32.
        * Ops whose type is in the black list run in fp32.
        * Ops whose type is in the white list run in fp16.
        * Gray-list ops follow the ops producing their inputs. They run in
          fp32 if any producer is fp32 or black-listed. Otherwise they run in
          fp16 if any producer is fp16 or white-listed. Otherwise they stay
          undecided.
        * Every other op runs in fp32 for numerical safety.
        """
        self._block._sync_with_cpp()
        ops = self._block.ops

        for op in ops:
            if int(op.attr('op_role')) == int(OpRole.Backward):
                break
            if op.type in ('create_py_reader', 'read'):
                continue
            op_id = op.desc.original_id()
            if amp_lists.black_varnames is not None and _is_in_black_varnames(
                op, amp_lists
            ):
                self._op_fp16_dict[op_id] = False
                continue

            if op.type in amp_lists.black_list:
                self._op_fp16_dict[op_id] = False
            elif op.type in amp_lists.white_list:
                self._op_fp16_dict[op_id] = True
            elif op.type in amp_lists.gray_list:
                is_black_op = False
                is_white_op = False
                for in_name in op.input_names:
                    # skip empty input slot names
                    if not in_name:
                        continue
                    for in_var_name in op.input(in_name):
                        in_var = self._block.var(in_var_name)
                        # this in_var isn't the output of another op
                        if in_var.op is None:
                            continue
                        if in_var.op is op:
                            # the var's recorded producer is this op itself,
                            # so look further back for the real producer
                            prev_op = find_true_prev_op(ops, op, in_var_name)
                            if prev_op is None:
                                continue
                        else:
                            prev_op = in_var.op
                        prev_state = self._is_fp16_op(prev_op.desc.original_id())
                        if (
                            prev_state is False
                            or prev_op.type in amp_lists.black_list
                        ):
                            is_black_op = True
                        elif (
                            prev_state is True
                            or prev_op.type in amp_lists.white_list
                        ):
                            is_white_op = True
                if is_black_op:
                    self._op_fp16_dict[op_id] = False
                elif is_white_op:
                    self._op_fp16_dict[op_id] = True
                # otherwise the gray op stays undecided
            else:
                # For numerical safety, ops that belong to no list run in fp32.
                self._op_fp16_dict[op_id] = False

    def cast_forward_program(self, dist_context):
        """
        Insert cast ops in front of every decided forward op.

        fp32 ops get their fp16 inputs cast to fp32. fp16 ops get their fp32
        inputs cast to fp16. Undecided ops are left unchanged.
        """
        ops = self._block.ops
        idx = 0
        while idx < len(ops):
            op = ops[idx]
            if int(op.attr('op_role')) == int(OpRole.Backward):
                break
            num_cast_ops = 0
            is_fp16 = self._is_fp16_op(op.desc.original_id())
            if is_fp16 is False:
                num_cast_ops = self._insert_cast_op_forward(
                    op,
                    idx,
                    core.VarDesc.VarType.FP16,
                    core.VarDesc.VarType.FP32,
                    dist_context,
                )
            elif is_fp16 is True:
                num_cast_ops = self._insert_cast_op_forward(
                    op,
                    idx,
                    core.VarDesc.VarType.FP32,
                    core.VarDesc.VarType.FP16,
                    dist_context,
                )
            # skip the op itself plus the cast ops just inserted before it
            idx += num_cast_ops + 1
        self._block._sync_with_cpp()

    def _insert_cast_op_forward(
        self, op, idx, src_dtype, dst_dtype, dist_context
    ):
        """
        Cast the ``src_dtype`` inputs of forward ``op`` to ``dst_dtype``.

        This is only for forward casts. It is modified from
        paddle.fluid.contrib.mixed_precision.

        Cast ops are inserted at ``idx``, in front of ``op``, without syncing
        the block, and ``op``'s inputs are renamed to the cast vars. The
        ``{old_name: cast_name}`` mapping is stored in
        ``self._var_name_dict`` under ``op``'s original id. When casting
        fp32 to fp16, ``op``'s fp32 outputs are switched to fp16 as well.

        Returns:
            int: the number of cast ops inserted.
        """
        num_cast_ops = 0
        var_name_dict = {}
        for in_name in op.input_names:
            if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
                op, in_name
            ):
                continue
            for in_var_name in op.input(in_name):
                in_var = self._block._find_var_recursive(in_var_name)
                if in_var.type not in _valid_types or in_var.dtype == dst_dtype:
                    continue
                if in_var.dtype != src_dtype:
                    if op.has_attr('in_dtype'):
                        op._set_attr('in_dtype', dst_dtype)
                    continue

                cast_name = in_var.name + '.cast_' + _dtype_to_str(dst_dtype)
                out_var = self._block.vars.get(cast_name)
                var_name_dict[in_var.name] = cast_name
                consume_op_attr = dist_context.get_op_dist_attr_for_program(op)
                assert consume_op_attr is not None
                in_var_dist_attr = consume_op_attr.get_input_dist_attr(
                    in_var.name
                )
                if out_var is None or out_var.dtype != dst_dtype:
                    # NOTE: the cast op and cast var take the dist attr that
                    # the consuming op expects for its input, not the dist
                    # attr of the op which generates the var.
                    assert in_var_dist_attr is not None
                    ref_mesh = in_var_dist_attr.process_mesh
                    ref_mapping = in_var_dist_attr.dims_mapping
                    consume_op_attr.set_input_dist_attr(
                        cast_name, in_var_dist_attr
                    )

                    out_var = self._block.create_var(
                        name=cast_name,
                        dtype=dst_dtype,
                        persistable=False,
                        stop_gradient=in_var.stop_gradient,
                    )
                    set_var_dist_attr(
                        dist_context, out_var, ref_mapping, ref_mesh
                    )

                    cast_op = self._block._insert_op_without_sync(
                        idx,
                        type="cast",
                        inputs={"X": in_var},
                        outputs={"Out": out_var},
                        attrs={
                            "in_dtype": in_var.dtype,
                            "out_dtype": out_var.dtype,
                        },
                    )
                    naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                        cast_op, ref_mesh, ref_mapping, dist_context
                    )
                    num_cast_ops += 1
                else:
                    # a cast var of the right dtype already exists: reuse it
                    consume_op_attr.set_input_dist_attr(
                        cast_name, in_var_dist_attr
                    )
                _rename_arg(op, in_var.name, cast_name)
        self._var_name_dict[op.desc.original_id()] = var_name_dict

        if (
            src_dtype == core.VarDesc.VarType.FP32
            and dst_dtype == core.VarDesc.VarType.FP16
        ):
            for out_name in op.output_names:
                if _keep_fp32_output(op, out_name):
                    continue
                for out_var_name in op.output(out_name):
                    out_var = self._block.var(out_var_name)
                    if out_var.type not in _valid_types:
                        continue
                    if out_var.dtype == core.VarDesc.VarType.FP32:
                        out_var.desc.set_dtype(core.VarDesc.VarType.FP16)
                        if op.has_attr('out_dtype'):
                            op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
        return num_cast_ops

    def cast_backward_program(self, params_grads, dist_context):
        """
        Make the ops after the loss op consistent with the forward casts.

        * Grad ops mapped to a decided forward op are handled by
          ``_insert_cast_op_backward``.
        * ``sum`` ops get an output dtype equal to their (shared) input dtype.
        * Ops with the loss-gradient role are left as is.
        * Any other op raises. Finally ``_update_backward_cast_ops`` moves the
          param-grad cast ops.

        Raises:
            ValueError: if an op after the loss op is not supported.
        """
        self._block._sync_with_cpp()
        ops = self._block.ops

        loss_op = get_loss_op(self._block)
        loss_op_index = find_op_index(self._block.desc, loss_op.desc)

        # Counts how often a backward op directly follows a forward/loss op,
        # excluding recompute ops. Used as the grad_var_to_var index.
        appended_grad_times = 0
        dist_op_context = dist_context.dist_op_context
        idx = loss_op_index + 1
        while idx < len(ops):
            num_cast_ops = 0
            grad_op = ops[idx]

            # NOTE: the map in `grad_var_to_var` may be changed when the var is casted,
            # which will affect the dist_op to insert allreduce_sum op.
            if is_backward_op(grad_op) and (
                is_forward_op(ops[idx - 1]) or is_loss_op(ops[idx - 1])
            ):
                op_dist_attr = dist_context.get_op_dist_attr_for_program(grad_op)
                if not op_dist_attr.is_recompute:
                    appended_grad_times += 1

            grad_op_orig_id = grad_op.desc.original_id()
            if grad_op_orig_id in dist_op_context.grad_op_id_to_op_id:
                is_fp16 = self._is_fp16_op(grad_op_orig_id)
                if is_fp16 is False:  # fp32
                    num_cast_ops = self._insert_cast_op_backward(
                        grad_op,
                        idx,
                        core.VarDesc.VarType.FP16,
                        core.VarDesc.VarType.FP32,
                        dist_context,
                        appended_grad_times,
                    )
                elif is_fp16 is True:  # fp16
                    num_cast_ops = self._insert_cast_op_backward(
                        grad_op,
                        idx,
                        core.VarDesc.VarType.FP32,
                        core.VarDesc.VarType.FP16,
                        dist_context,
                        appended_grad_times,
                    )
            elif grad_op.type == "sum":
                # all inputs must share one dtype; the output follows it
                in_var_names = grad_op.desc.input_arg_names()
                src_dtype = self._block.var(in_var_names[0]).dtype
                for in_var_name in in_var_names:
                    assert src_dtype == self._block.var(in_var_name).dtype
                out_var = self._block.var(grad_op.desc.output_arg_names()[0])
                if out_var.dtype != src_dtype:
                    out_var.desc.set_dtype(src_dtype)
            elif int(grad_op.attr('op_role')) == self._LOSS_GRAD_OP_ROLE:
                pass
            else:
                raise ValueError(
                    "'{}' op is not supported in the complete amp pass.".format(
                        grad_op.type
                    )
                )
            idx += num_cast_ops + 1

        self._block._sync_with_cpp()
        _update_backward_cast_ops(params_grads, dist_context)

    def _insert_cast_op_backward(
        self,
        grad_op,
        idx,
        src_dtype,
        dst_dtype,
        dist_context,
        appended_grad_times,
    ):
        """
        Make ``grad_op`` consistent with the casts of its forward op.

        This is only for backward casts.

        * Inputs of ``src_dtype`` that were cast in the forward segment are
          renamed to the cast var.
        * Each output takes the dtype of the forward var it is the gradient
          of.
        * An output of ``src_dtype`` whose forward var was cast is redirected
          to ``<cast_name>@GRAD<suffix>`` (a new ``dst_dtype`` var). A cast op
          inserted right after ``grad_op`` converts it back into the original
          output.

        Returns:
            int: the number of cast ops inserted after ``grad_op``.
        """

        # Named differently from the module-level _keep_fp32_input/_output
        # imports so they are not shadowed.
        def _keep_fp32_grad_input(op, in_name):
            if op.type in ['layer_norm_grad']:
                return in_name not in {'X', 'Y@GRAD'}
            return False

        def _keep_fp32_grad_output(op, out_name):
            if op.type in ['layer_norm_grad']:
                return out_name != 'X@GRAD'
            return False

        fp32 = core.VarDesc.VarType.FP32
        num_cast_ops = 0
        dist_op_context = dist_context.dist_op_context
        fwd_op_id = dist_op_context.grad_op_id_to_op_id[
            grad_op.desc.original_id()
        ]
        # {old_name: cast_name} recorded when the forward op was cast
        fwd_cast_names = self._var_name_dict.get(fwd_op_id, {})
        consume_op_attr = dist_context.get_op_dist_attr_for_program(grad_op)

        for in_name in grad_op.input_names:
            if src_dtype == fp32 and _keep_fp32_grad_input(grad_op, in_name):
                for in_var_name in grad_op.input(in_name):
                    in_var = self._block._find_var_recursive(in_var_name)
                    assert in_var.dtype == fp32
                continue

            for in_var_name in grad_op.input(in_name):
                in_var = self._block._find_var_recursive(in_var_name)
                if in_var.dtype != src_dtype:
                    continue
                if in_var_name in fwd_cast_names:
                    # NOTE: if in_var of consume grad_op has been casted before,
                    # it should be renamed and reset dist_attr.
                    cast_name = fwd_cast_names[in_var_name]
                    grad_op.desc._rename_input(in_var_name, cast_name)
                    in_var_dist_attr = consume_op_attr.get_input_dist_attr(
                        in_var_name
                    )
                    consume_op_attr.set_input_dist_attr(
                        cast_name, in_var_dist_attr
                    )
                else:
                    assert (
                        in_var.dtype == dst_dtype
                    ), "op [{}] expect input [{}] to be dtype [{}] BUT got [{}]. {}".format(
                        grad_op.type,
                        in_name,
                        dst_dtype,
                        in_var.dtype,
                        str(grad_op),
                    )

        for out_name in grad_op.output_names:
            if src_dtype == fp32 and _keep_fp32_grad_output(grad_op, out_name):
                for out_var_name in grad_op.output(out_name):
                    out_var = self._block._find_var_recursive(out_var_name)
                    assert out_var.dtype == fp32
                continue

            for out_var_name in grad_op.output(out_name):
                out_var = self._block._find_var_recursive(out_var_name)
                # "x@GRAD@RENAME@..." -> "x". split() keeps names without "@"
                # intact. The previous name[:name.find("@")] dropped their
                # last character because find() returns -1.
                out_var_name_prefix = out_var_name.split("@", 1)[0]
                fwd_var = self._block._find_var_recursive(out_var_name_prefix)
                # NOTE: the out_var's dtype of consume grad_op should equal to the fwd_var's dtype
                if out_var.dtype != fwd_var.dtype:
                    out_var.desc.set_dtype(fwd_var.dtype)

                if out_var.dtype != src_dtype:
                    assert out_var.dtype == dst_dtype
                    continue
                if out_var_name_prefix not in fwd_cast_names:
                    continue

                # NOTE: if out_var of consume grad_op has been casted before,
                # it should be renamed and reset dist_attr, then we insert cast op to
                # convert the cast_var to original dtype
                fwd_cast_name = fwd_cast_names[out_var_name_prefix]
                suffix = ""
                if "@RENAME" in out_var_name:
                    suffix = out_var_name[out_var_name.find("@RENAME") :]
                cast_name = fwd_cast_name + "@GRAD" + suffix
                cast_var = self._block.vars.get(cast_name)
                if cast_var is not None and cast_var.dtype == dst_dtype:
                    continue

                grad_op.desc._rename_output(out_var_name, cast_name)
                out_var_dist_attr = consume_op_attr.get_output_dist_attr(
                    out_var_name
                )
                ref_mesh = out_var_dist_attr.process_mesh
                ref_mapping = out_var_dist_attr.dims_mapping
                consume_op_attr.set_output_dist_attr(
                    cast_name, out_var_dist_attr
                )
                assert ref_mapping is not None
                cast_var = self._block.create_var(
                    name=cast_name,
                    shape=out_var.shape,
                    dtype=dst_dtype,
                    persistable=False,
                    stop_gradient=out_var.stop_gradient,
                )
                set_var_dist_attr(dist_context, cast_var, ref_mapping, ref_mesh)
                dist_op_context.grad_var_to_var[appended_grad_times][
                    cast_name
                ] = fwd_cast_name

                cast_op = self._block._insert_op(
                    idx + 1,
                    type="cast",
                    inputs={"X": cast_var},
                    outputs={"Out": out_var},
                    attrs={
                        "in_dtype": cast_var.dtype,
                        "out_dtype": out_var.dtype,
                        "op_role": OpRole.Backward,
                    },
                )
                cast_op._remove_attr("op_role_var")
                cast_op._remove_attr("op_namescope")
                cast_op._remove_attr("with_quant_attr")
                naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                    cast_op, ref_mesh, ref_mapping, dist_context
                )
                num_cast_ops += 1

        return num_cast_ops


def _update_backward_cast_ops(params_grads, dist_context):
    """
    Move each fp32 param-grad cast op to the end of the main block, in order
    to enable fp16 allreduce.

    For every ``(param, grad)`` pair whose fp32 grad is produced by a
    ``cast`` op:

    * the op's ``op_role_var`` attr is dropped;
    * it is checked that no non-optimize op consumes its output;
    * a copy is appended at the end of the global block, with the param's
      dist attr;
    * the original op is removed.

    NOTE(review): the copy goes to the very end of the block. This is only
    "the end of the backward segment" if no optimizer ops exist yet (see the
    TODO below).

    Raises:
        ValueError: if the cast output is used by a later op, or the cast op
            cannot be found in the program.
    """
    # TODO filter optimize ops in future

    main_block = paddle.static.default_main_program().global_block()
    main_block._sync_with_cpp()

    for p, g in params_grads:
        op = g.op
        # Only fp32 grads written by a cast op are moved. A grad without a
        # producing op has nothing to move.
        if op is None or g.dtype != core.VarDesc.VarType.FP32:
            continue
        if op.type != 'cast':
            continue

        if int(op.attr('op_role')) == int(OpRole.Backward) and op.has_attr(
            'op_role_var'
        ):
            op._remove_attr("op_role_var")

        post_ops = find_true_post_op(main_block.ops, op, g.name)
        if post_ops:
            raise ValueError(
                "The cast op {0}'s output should not be "
                "used by a non-optimize op, however, it "
                "is used by {1}".format(op, post_ops[0])
            )

        # already the last op of the block: nothing to move
        if op == main_block.ops[-1]:
            continue

        # add new op in the python and cpp at the same time
        new_op_desc = main_block.desc.append_op()
        new_op_desc.copy_from(op.desc)
        new_op = paddle.fluid.framework.Operator(
            block=main_block,
            desc=new_op_desc,
            type=None,
            inputs=None,
            outputs=None,
            attrs=None,
        )
        main_block.ops.append(new_op)

        # the moved cast op and its output follow the param's dist attr
        param_dist_attr = dist_context.get_tensor_dist_attr_for_program(p)
        output_dist_attr = dist_context.get_tensor_dist_attr_for_program(
            main_block.var(op.output_arg_names[0])
        )
        assert param_dist_attr is not None
        assert output_dist_attr is not None
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            new_op,
            param_dist_attr.process_mesh,
            param_dist_attr.dims_mapping,
            dist_context,
        )

        output_dist_attr.process_mesh = param_dist_attr.process_mesh
        output_dist_attr.dims_mapping = param_dist_attr.dims_mapping

        op_idx = find_op_index(main_block.desc, op.desc)
        if op_idx == -1:
            raise ValueError("The op {0} is not in program".format(op))
        main_block._remove_op(op_idx, sync=False)

    main_block._sync_with_cpp()


def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
    """
    Append a ``check_finite_and_unscale`` op over all gradients.

    The op reads the gradients and ``loss_scaling`` (as ``Scale``). It writes
    its ``Out`` back into the same gradient vars, i.e. in place. It also
    writes a new bool var (``FoundInfinite``). The op and the flag var are
    both given the ranks of ``world_process_group`` as their process mesh.
    Each gradient keeps its own dims mapping.

    Returns:
        tuple: ``(grads, found_inf)``, i.e. the list of gradient vars and the
        bool flag var.
    """
    main_block = paddle.static.default_main_program().global_block()
    main_block._sync_with_cpp()

    grads = [g for _, g in params_grads]
    check_type(grads, 'x', (tuple, list), 'check_finite_and_unscale')
    for grad in grads:
        check_variable_and_dtype(
            grad,
            "x",
            ['float16', 'float32', 'float64'],
            'check_finite_and_unscale',
        )

    found_inf = main_block.create_var(
        name=unique_name.generate_with_ignorable_key(
            ".".join(['find_infinite_scale', 'tmp'])
        ),
        shape=[1],
        dtype='bool',
        type=core.VarDesc.VarType.LOD_TENSOR,
        persistable=False,
        stop_gradient=False,
    )
    set_var_dist_attr(dist_context, found_inf, [-1], world_process_group.ranks)

    new_op = main_block.append_op(
        type='check_finite_and_unscale',
        inputs={'X': grads, 'Scale': loss_scaling},
        outputs={'Out': grads, 'FoundInfinite': found_inf},
        attrs={'op_role': OpRole.Optimize},
    )

    new_op_dist_attr = OperatorDistributedAttribute()
    new_op_dist_attr.process_mesh = world_process_group.ranks
    new_op_dist_attr.impl_idx = 0
    if len(world_process_group.ranks) > 1:
        new_op_dist_attr.impl_type = "check_finite_and_unscale"
    for g in grads:
        g_dist_attr = dist_context.get_tensor_dist_attr_for_program(g)
        assert g_dist_attr is not None
        new_op_dist_attr.set_input_dims_mapping(
            g.name, g_dist_attr.dims_mapping
        )
        new_op_dist_attr.set_output_dims_mapping(
            g.name, g_dist_attr.dims_mapping
        )
    dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
    return grads, found_inf


@register_pass("auto_parallel_amp")
class AMPPass(PassBase):
    """Auto-parallel automatic mixed precision (AMP) pass.

    Rewrites a static ``main_program`` for mixed precision:

    * casts the forward program (and, for training programs, the backward
      program) through ``AMPState``, using AMP lists built from the
      ``custom_*`` attributes;
    * for training programs, creates the loss-scaling state variables,
      inserts loss cast / loss scaling ops around the loss op, and appends
      ``check_finite_and_unscale`` and (with dynamic loss scaling)
      ``update_loss_scaling`` ops.

    Pass attributes (configured through ``set_attr``):
        loss: loss variable of the main program; required by ``_scale_loss``.
        dist_context: distributed context; required (see ``_check_self``).
        custom_white_list, custom_black_list, custom_black_varnames:
            optional iterables customizing the AMP lists; ``None`` means
            no customization.
        init_loss_scaling, incr_every_n_steps, decr_every_n_nan_or_inf,
        incr_ratio, decr_ratio, use_dynamic_loss_scaling, stop_update:
            loss-scaling configuration.
        params_grads: list of ``(param, grad)`` pairs.
        input_data: defaulted here; not read by the methods below.
    """

    def __init__(self):
        super().__init__()
        self.set_attr("loss", None)
        self.set_attr("dist_context", None)
        self.set_attr("custom_white_list", None)
        self.set_attr("custom_black_list", None)
        self.set_attr("custom_black_varnames", None)
        self.set_attr("init_loss_scaling", 32768.0)
        self.set_attr("incr_every_n_steps", 1000)
        self.set_attr("decr_every_n_nan_or_inf", 2)
        self.set_attr("incr_ratio", 2.0)
        self.set_attr("decr_ratio", 0.8)
        self.set_attr("use_dynamic_loss_scaling", False)
        # Read by ``_update_loss_scaling``; give it an explicit default
        # instead of relying on ``get_attr`` returning None.
        self.set_attr("stop_update", False)
        self.set_attr("input_data", [])
        self.set_attr("params_grads", [])

        # Effective loss after the pass (see ``get_loss``).
        self._loss = None
        # Loss-scaling state, created by ``_init_amp_var``.
        self._loss_scaling = None
        self._num_good_steps = None
        self._num_bad_steps = None

    def _check_self(self):
        """Return False if any numeric knob is negative or no dist_context."""
        non_negative_attrs = (
            "init_loss_scaling",
            "incr_every_n_steps",
            "decr_every_n_nan_or_inf",
            "incr_ratio",
            "decr_ratio",
        )
        if any(self.get_attr(name) < 0 for name in non_negative_attrs):
            return False
        return self.get_attr("dist_context") is not None

    def _check_conflict(self, other_pass):
        """AMP declares no conflict with any other pass."""
        return True

    # NOTE: why can AMPPass override _apply_single_impl instead of
    # _apply_impl? AMP is an optimization pass for the serial program; in
    # the distributed scenario, every rank must apply the same modification.
    def _apply_single_impl(self, main_program, startup_program, context):
        """Apply AMP to ``main_program`` (and ``startup_program``).

        The forward program is always cast. The backward rewrite, loss
        scaling and gradient unscale/update ops are only added when
        ``AMPState._build_state`` reports a training program.
        """
        self.dist_context = self.get_attr("dist_context")
        params_grads = self.get_attr("params_grads")

        # The custom lists default to None in ``__init__``; ``set(None)``
        # would raise TypeError, so a missing list is treated as empty.
        amp_lists = AutoMixedPrecisionLists(
            set(self.get_attr("custom_white_list") or ()),
            set(self.get_attr("custom_black_list") or ()),
            set(self.get_attr("custom_black_varnames") or ()),
        )

        with paddle.static.program_guard(main_program, startup_program):
            amp_state = AMPState(main_program.global_block())
            is_train = amp_state._build_state(amp_lists, self.dist_context)
            amp_state.cast_forward_program(self.dist_context)

        if not is_train:
            return

        with paddle.static.program_guard(main_program, startup_program):
            amp_state.cast_backward_program(params_grads, self.dist_context)
            self._init_amp_var()
            self._scale_loss()

            if (
                self.get_attr("use_dynamic_loss_scaling")
                or self.get_attr("init_loss_scaling") != 1.0
            ):
                grads, found_inf = _check_and_update_gradient(
                    params_grads, self._loss_scaling, self.dist_context
                )
                # Dynamic scaling implies the condition above, so ``grads``
                # and ``found_inf`` are always bound here.
                if self.get_attr("use_dynamic_loss_scaling"):
                    self._update_loss_scaling(grads, found_inf)

    def _init_amp_var(self):
        """Create the loss-scaling state variables.

        Always creates ``self._loss_scaling`` (float32, initialized to
        ``init_loss_scaling``). With dynamic loss scaling it also creates the
        int32 counters ``self._num_good_steps`` / ``self._num_bad_steps``,
        initialized to 0.
        """
        self._loss_scaling = self._create_amp_global_var(
            "loss_scaling", self.get_attr("init_loss_scaling"), 'float32'
        )

        if self.get_attr("use_dynamic_loss_scaling"):
            self._num_good_steps = self._create_amp_global_var(
                "num_good_steps", 0, 'int32'
            )
            self._num_bad_steps = self._create_amp_global_var(
                "num_bad_steps", 0, 'int32'
            )

    def _create_amp_global_var(self, name, value, dtype):
        """Create a persistable ``[1]``-shaped global var for AMP state.

        The var gets a unique name derived from ``name`` and is given
        dims_mapping ``[-1]`` on the world process group.
        """
        var = paddle.static.create_global_var(
            name=unique_name.generate(name),
            shape=[1],
            value=value,
            dtype=dtype,
            persistable=True,
        )
        set_var_dist_attr(
            self.dist_context,
            var,
            [-1],
            world_process_group.ranks,
        )
        return var

    @staticmethod
    def _find_loss_grad_op(main_block, loss_op_idx, op_role_key):
        """Return the ``fill_constant`` op that produces the loss gradient.

        Callers have just inserted one forward op right after the loss op
        at ``loss_op_idx``. The op that originally followed the loss op is
        therefore now at ``loss_op_idx + 2``. It must be a ``fill_constant``
        whose op role is ``Backward | Loss``.
        """
        loss_grad_role = int(OpRole.Backward) | int(OpRole.Loss)  # == 257
        first_backward_op = main_block.ops[loss_op_idx + 2]
        assert (
            first_backward_op.type == "fill_constant"
            and int(first_backward_op.all_attrs()[op_role_key])
            == loss_grad_role
        )
        return first_backward_op

    def _scale_loss(self):
        """Insert loss cast / loss scaling ops around the loss op.

        * If the loss is not FP32, a ``cast`` to FP32 is inserted right after
          the loss op. In the backward section, the gradient of the FP32 loss
          is cast back to the original loss dtype.
        * If loss scaling is enabled (dynamic scaling, or a static scale
          other than 1.0), an ``elementwise_mul`` by ``self._loss_scaling`` is
          inserted after the (possibly cast) loss op. The matching
          ``elementwise_mul_grad`` is inserted in the backward section.

        Sets ``self._scaled_loss`` to the scaled loss (the unscaled loss when
        scaling is disabled) and ``self._loss`` to the (possibly cast) loss.
        """
        main_block = paddle.static.default_main_program().global_block()
        main_block._sync_with_cpp()
        OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()

        loss = self.get_attr("loss")
        assert loss is not None
        loss_op = loss.op
        loss_op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
            loss_op
        )

        if loss.dtype != core.VarDesc.VarType.FP32:
            # The loss gradient must be cast back to this dtype below.
            orig_loss_dtype = loss.dtype

            tmp_name = unique_name.generate(loss.name + ".cast_fp32")
            cast_loss = main_block.create_var(
                name=tmp_name, dtype=core.VarDesc.VarType.FP32
            )
            loss_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(
                loss
            )
            ref_mesh = loss_op_dist_attr.process_mesh
            self.dist_context.set_tensor_dist_attr_for_program(
                cast_loss, loss_dist_attr
            )

            # forward
            loss_op_idx = find_op_index(main_block.desc, loss_op.desc)
            cast_op = main_block._insert_op(
                loss_op_idx + 1,
                type='cast',
                inputs={'X': [loss]},
                outputs={'Out': [cast_loss]},
                attrs={
                    "in_dtype": orig_loss_dtype,
                    "out_dtype": core.VarDesc.VarType.FP32,
                    'op_role': loss_op.all_attrs()[OP_ROLE_KEY],
                },
            )

            loss_op._set_attr(
                OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Forward
            )
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                cast_op, ref_mesh, [-1], self.dist_context
            )

            # backward
            first_backward_op = self._find_loss_grad_op(
                main_block, loss_op_idx, OP_ROLE_KEY
            )
            cast_loss_grad = main_block.create_var(
                name=unique_name.generate(tmp_name + "@GRAD"),
                shape=loss.shape,
                dtype=core.VarDesc.VarType.FP32,
                persistable=loss.persistable,
            )
            set_var_dist_attr(self.dist_context, cast_loss_grad, [-1], ref_mesh)

            pre_grad_name = first_backward_op.output_arg_names[0]
            # NOTE(review): only the output is renamed; the fill_constant keeps
            # whatever ``dtype`` attribute it had while its output var is now
            # declared FP32 - confirm this is intended.
            first_backward_op._rename_output(pre_grad_name, cast_loss_grad.name)
            cast_grad_op = main_block._insert_op(
                loss_op_idx + 3,
                type='cast',
                inputs={'X': [cast_loss_grad]},
                outputs={'Out': [pre_grad_name]},
                attrs={
                    "in_dtype": core.VarDesc.VarType.FP32,
                    # ``pre_grad_name`` has the original loss dtype; this was
                    # previously hard-coded to FP16, which is wrong for
                    # other non-FP32 losses.
                    "out_dtype": orig_loss_dtype,
                    'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
                },
            )
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                cast_grad_op, ref_mesh, [-1], self.dist_context
            )
            # Any scaling below applies to the FP32 loss.
            loss_op = cast_op
            loss = cast_loss

        if (
            self.get_attr("use_dynamic_loss_scaling")
            or self.get_attr("init_loss_scaling") != 1.0
        ):
            loss_op_idx = find_op_index(main_block.desc, loss_op.desc)

            # forward
            ref_mesh = loss_op_dist_attr.process_mesh
            self._scaled_loss = main_block.create_var(
                name=unique_name.generate("scaled_loss"),
                shape=loss.shape,
                dtype=loss.dtype,
                persistable=loss.persistable,
            )
            set_var_dist_attr(
                self.dist_context, self._scaled_loss, [-1], ref_mesh
            )

            elementwise_mul_op = main_block._insert_op(
                loss_op_idx + 1,
                type='elementwise_mul',
                inputs={'X': [loss], 'Y': [self._loss_scaling]},
                outputs={'Out': [self._scaled_loss]},
                attrs={
                    'op_role': loss_op.all_attrs()[OP_ROLE_KEY],
                },
            )
            loss_op._set_attr(
                OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Forward
            )
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                elementwise_mul_op, ref_mesh, [-1], self.dist_context
            )

            # backward
            first_backward_op = self._find_loss_grad_op(
                main_block, loss_op_idx, OP_ROLE_KEY
            )
            self._scaled_loss_grad = main_block.create_var(
                name=unique_name.generate("scaled_loss") + "@GRAD",
                shape=loss.shape,
                dtype=loss.dtype,
                persistable=loss.persistable,
            )
            set_var_dist_attr(
                self.dist_context, self._scaled_loss_grad, [-1], ref_mesh
            )
            pre_grad_name = first_backward_op.output_arg_names[0]
            first_backward_op._rename_output(
                pre_grad_name, self._scaled_loss_grad.name
            )
            # FIXME(JZ-LIANG) a trick to insert backward op
            main_block._sync_with_cpp()
            elementwise_mul_grad_op_desc = main_block.desc._insert_op(
                loss_op_idx + 3
            )
            elementwise_mul_grad_op_desc.set_type("elementwise_mul_grad")
            elementwise_mul_grad_op_desc.set_input(
                'Out@GRAD', [self._scaled_loss_grad.name]
            )
            elementwise_mul_grad_op_desc.set_input('X', [loss.name])
            elementwise_mul_grad_op_desc.set_input(
                'Y', [self._loss_scaling.name]
            )
            elementwise_mul_grad_op_desc.set_output('X@GRAD', [pre_grad_name])
            elementwise_mul_grad_op_desc.set_output('Y@GRAD', [])
            elementwise_mul_grad_op_desc._set_attr(
                OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Backward
            )
            elementwise_mul_grad_op_desc._set_attr('axis', -1)
            elementwise_mul_grad_op = paddle.fluid.framework.Operator(
                main_block, elementwise_mul_grad_op_desc
            )
            main_block.ops.insert(loss_op_idx + 3, elementwise_mul_grad_op)
            main_block._sync_with_cpp()
            elementwise_mul_grad_op = main_block.ops[loss_op_idx + 3]
            assert elementwise_mul_grad_op.type == "elementwise_mul_grad"
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                elementwise_mul_grad_op, ref_mesh, [-1], self.dist_context
            )

        else:
            self._scaled_loss = loss
        self._loss = loss
        main_block._sync_with_cpp()

    def _update_loss_scaling(self, grads, found_inf):
        """Append an ``update_loss_scaling`` op to the main program.

        Args:
            grads: list/tuple of gradient variables (float16/32/64).
            found_inf: the bool variable returned by
                ``_check_and_update_gradient``.

        The op takes ``self._loss_scaling``, ``self._num_good_steps`` and
        ``self._num_bad_steps`` as both inputs and outputs, and also writes
        back ``grads``.
        """
        main_block = paddle.static.default_main_program().global_block()
        main_block._sync_with_cpp()

        check_variable_and_dtype(
            self._loss_scaling,
            "prev_loss_scaling",
            ['float32', 'float64'],
            "update_loss_scaling",
        )
        check_type(grads, 'x', (tuple, list), 'update_loss_scaling')
        for e in grads:
            check_variable_and_dtype(
                e, "x", ['float16', 'float32', 'float64'], 'update_loss_scaling'
            )
            if e.dtype == core.VarDesc.VarType.FP16:
                assert (
                    self._loss_scaling.dtype == core.VarDesc.VarType.FP32
                ), "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
            else:
                assert (
                    self._loss_scaling.dtype == e.dtype
                ), "The dtype of prev_loss_scaling should be equal to the dtype of x."

        inputs = {
            'X': grads,
            'FoundInfinite': found_inf,
            'PrevLossScaling': self._loss_scaling,
            'InGoodSteps': self._num_good_steps,
            'InBadSteps': self._num_bad_steps,
        }

        outputs = {
            'Out': grads,
            'LossScaling': self._loss_scaling,
            'OutGoodSteps': self._num_good_steps,
            'OutBadSteps': self._num_bad_steps,
        }

        attrs = {
            'incr_every_n_steps': self.get_attr("incr_every_n_steps"),
            'decr_every_n_nan_or_inf': self.get_attr("decr_every_n_nan_or_inf"),
            'incr_ratio': self.get_attr("incr_ratio"),
            'decr_ratio': self.get_attr("decr_ratio"),
            'stop_update': self.get_attr("stop_update"),
            'op_role': OpRole.Optimize,
        }

        new_op = main_block.append_op(
            type='update_loss_scaling',
            inputs=inputs,
            outputs=outputs,
            attrs=attrs,
        )

        new_op_dist_attr = OperatorDistributedAttribute()
        new_op_dist_attr.process_mesh = world_process_group.ranks
        new_op_dist_attr.impl_idx = 0
        if len(world_process_group.ranks) > 1:
            new_op_dist_attr.impl_type = "update_loss_scaling"
        for g in grads:
            g_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(g)
            assert g_dist_attr is not None
            new_op_dist_attr.set_input_dims_mapping(
                g.name, g_dist_attr.dims_mapping
            )
            new_op_dist_attr.set_output_dims_mapping(
                g.name, g_dist_attr.dims_mapping
            )
        self.dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)

        main_block._sync_with_cpp()

    def get_loss(self):
        """Return the effective loss after this pass.

        AMP/FP16 may replace the loss variable (e.g. with an FP32 cast of
        it). That affects subsequent passes that rely on the loss. Falls back
        to the ``loss`` attribute when the pass has not recorded one.
        """
        # Explicit None check: "has the pass set a loss" must not depend on
        # the truth value of a Variable object.
        if self._loss is not None:
            return self._loss
        return self.get_attr("loss")