# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.fluid import core, unique_name
from functools import reduce
from paddle.distributed.fleet.meta_optimizers.common import (
    is_loss_grad_op,
    is_backward_op,
    is_optimizer_op,
)
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole

import re
import os


def check_broadcast(block):
    """
    if a var is broadcasted, there should be a sync_comm op before
    this var is used; if not, raise an error.
    if the broadcasted var has a fill_constant op, the fill_constant
    op should come before the broadcast op, and before a
    sync_calc op. Otherwise, raise an error.

    broadcast ops of inner parallelism (e.g. Megatron) should be ignored and skipped
    """
    broadcast_vars = {}
    for idx, op in enumerate(block.ops):
        if op.type == "c_broadcast":
            if not op.all_attrs()["use_calc_stream"]:
                var_name = op.desc.input_arg_names()[0]
                if "@BroadCast" in var_name:
                    if var_name in broadcast_vars:
                        raise ValueError(
                            "var_name areadly exist: {}"
                            "the old pos is {}, the new pos is {}".format(
                                var_name,
                                broadcast_vars[var_name]["broadcast_pos"],
                                idx,
                            )
                        )
                    broadcast_vars[var_name] = {
                        "fill_constant_pos": -1,
                        "broadcast_pos": idx,
                    }

    for idx, op in enumerate(block.ops):
        if op.type == "fill_constant":
            var_name = op.desc.output_arg_names()[0]
            if var_name in broadcast_vars:
                broadcast_vars[var_name]["fill_constant_pos"] = idx
            continue

    last_sync_comm_op_idx = -1
    last_sync_calc_op_idx = -1
    for idx, op in enumerate(block.ops):
        if op.type == "c_sync_comm_stream":
            last_sync_comm_op_idx = idx
            continue
        if op.type == "c_sync_calc_stream":
            last_sync_calc_op_idx = idx
            continue
        if op.type == "c_broadcast":
            if not op.all_attrs()["use_calc_stream"]:
                var_name = op.desc.input_arg_names()[0]
                if "@BroadCast" in var_name:
                    if broadcast_vars[var_name]["fill_constant_pos"] != -1:
                        assert last_sync_calc_op_idx != -1
                        assert (
                            broadcast_vars[var_name]["fill_constant_pos"]
                            < last_sync_calc_op_idx
                        )
                        assert last_sync_calc_op_idx < idx
                    continue
        for input_name in op.desc.input_arg_names():
            if input_name in broadcast_vars:
                assert broadcast_vars[input_name]["broadcast_pos"] != -1
                assert (
                    broadcast_vars[input_name]["broadcast_pos"]
                    < last_sync_comm_op_idx
                )
                assert last_sync_comm_op_idx < idx
    return
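

# Illustrative op ordering that check_broadcast verifies (a hedged sketch with a
# hypothetical var name, not code from this module):
#
#     fill_constant(Out=fc_0.w_0@BroadCast)   # optional, fills the buffer
#     c_sync_calc_stream                      # must come before the broadcast
#     c_broadcast(X=fc_0.w_0@BroadCast)
#     c_sync_comm_stream                      # must come before any consumer
#     some_op(X=fc_0.w_0@BroadCast)           # first use of the broadcast var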


def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1):
    """
    the op order should be:
        grad:
            - 0: op that generate Var
            - 1: sync_calc
            - 2: reduce_sum_sharding (allreduce --> reduce)
            - 3: sync_comm
            - 4: allreduce_sum_dp (dp_grads)
            - 5: sync_comm (dp_grads)
            - 6: op that use Var (dp_grads & sum)

    allreduce ops of inner parallelism (e.g. Megatron) should be ignored and skipped
    """
    vars_status = {}
    dp_grads_status = {}
    idx_last_grad_allreduce = -1
    idx_amp_allreduce = -1
    idx_gradient_clip_allreduce = -1

    for idx, op in enumerate(block.ops):
        # sharding uses both allreduce and reduce to sync grads
        if op.type == "c_allreduce_sum" or op.type == "c_reduce_sum":
            if not op.all_attrs()["use_calc_stream"]:
                ring_id = op.desc.attr("ring_id")
                var_name = op.desc.input_arg_names()[0]
                param = var_name.split("@")[0]

                assert 'sum' in var_name or ("@GRAD" in var_name)
                if 'sum' in var_name or (not shard.has_param(param)):
                    vars_status[var_name] = -1
                else:
                    dp_grads_status[var_name] = -1

                if ring_id != sharding_ring_id:
                    assert shard.has_param(param)
                    assert ring_id == dp_ring_id

                if "sum" in var_name:
                    idx_amp_allreduce = idx
                elif "@GRAD":
                    idx_last_grad_allreduce = idx

        if op.type == "c_allreduce_max":
            idx_gradient_clip_allreduce = idx

    for op in block.ops:
        if op.type == "c_sync_calc_stream":
            for var_name in vars_status:
                if var_name in vars_status and vars_status[var_name] == 0:
                    vars_status[var_name] = 1
            for var_name in dp_grads_status:
                if (
                    var_name in dp_grads_status
                    and dp_grads_status[var_name] == 0
                ):
                    dp_grads_status[var_name] = 1
        # check sharding allreduce and reduce, but skip Megatron allreduce
        elif op.type == "c_allreduce_sum" or op.type == "c_reduce_sum":
            if not op.all_attrs()["use_calc_stream"]:
                var_name = op.desc.input_arg_names()[0]
                ring_id = op.desc.attr("ring_id")
                if ring_id == sharding_ring_id:
                    assert (
                        op.type == "c_reduce_sum"
                    ), "Grad in Sharding group should be reduce rather than allreduce"
                    if var_name in vars_status:
                        _status = vars_status[var_name]
                    else:
                        _status = dp_grads_status[var_name]
                    if _status == -1:
                        raise ValueError(
                            "{} is not generated, but you are"
                            "trying to all-reduce it".format(var_name)
                        )
                    if _status == 0:
                        raise ValueError(
                            "There should be a sync_calc op "
                            "after generate Var: {} and before the"
                            "c_allreduce_sum op".format(var_name)
                        )
                    assert _status == 1
                    if var_name in vars_status:
                        vars_status[var_name] = 2
                    else:
                        dp_grads_status[var_name] = 2
                else:
                    assert ring_id == dp_ring_id
                    param = var_name.split("@")[0]
                    assert shard.has_param(param)
                    assert dp_grads_status[var_name] == 3
                    dp_grads_status[var_name] = 4

        elif op.type == "c_sync_comm_stream":
            var_name = op.desc.input_arg_names()[0]
            ring_id = op.desc.attr("ring_id")
            if ring_id == sharding_ring_id:
                for var_name in op.desc.input_arg_names():
                    if var_name in vars_status:
                        assert vars_status[var_name] == 2
                        vars_status[var_name] = 3
                    elif var_name in dp_grads_status:
                        assert dp_grads_status[var_name] == 2
                        dp_grads_status[var_name] = 3
            else:
                for var_name in op.desc.input_arg_names():
                    param = var_name.split("@")[0]
                    assert ring_id == dp_ring_id
                    assert shard.has_param(param)
                    assert dp_grads_status[var_name] == 4
                    dp_grads_status[var_name] = 5
        else:
            for input_name in op.desc.input_arg_names():
                if input_name in vars_status:
                    if vars_status[input_name] != 3:
                        raise ValueError(
                            "There should be a sync_comm op "
                            "after allreduce the Var: {}".format(input_name)
                        )
                    raise ValueError(
                        "The reduce output grad [{}] should NOT be be used in Non-root rank.".format(
                            input_name
                        )
                    )
                if input_name in dp_grads_status:
                    if dp_ring_id == -1:
                        if dp_grads_status[input_name] != 3:
                            raise ValueError(
                                "There should be a sync_comm op "
                                "after allreduce the Var: {}".format(input_name)
                            )
                    else:
                        if dp_grads_status[input_name] != 5:
                            raise ValueError(
                                "The grad in shard should be allreduce and sync"
232 233
                                "twice before usage {}".format(input_name)
                            )

            for output_name in op.desc.output_arg_names():
                if (
                    output_name in vars_status
                    and vars_status[output_name] == -1
                ):
                    vars_status[output_name] = 0
                if (
                    output_name in dp_grads_status
                    and dp_grads_status[output_name] == -1
                ):
                    dp_grads_status[output_name] = 0

    # check sharding with amp
    if idx_amp_allreduce != -1:
        assert idx_amp_allreduce > idx_last_grad_allreduce

    # check sharding with gradient_clip_by_global_norm
    if idx_gradient_clip_allreduce != -1:
        assert idx_gradient_clip_allreduce > idx_last_grad_allreduce

    return


def get_valid_op_role(block, insert_idx):
    """
    return OpRole.Forward or OpRole.Backward
    """
    if insert_idx >= len(block.ops):
        return OpRole.Backward
    op_role = block.ops[insert_idx].attr('op_role')
    if op_role in [int(OpRole.Backward), int(OpRole.Optimize)]:
        return OpRole.Backward
    if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
        return OpRole.Forward

    return get_valid_op_role(block, insert_idx + 1)


def insert_sync_calc_op(block, insert_idx, calc_dep_vars):
    """
    _insert_sync_calc_op
    """
    op_role = get_valid_op_role(block, insert_idx)
    block._insert_op_without_sync(
        insert_idx,
        type='c_sync_calc_stream',
        inputs={'X': calc_dep_vars},
        outputs={'Out': calc_dep_vars},
        attrs={OP_ROLE_KEY: op_role},
    )
    return


def insert_sync_comm_op(block, insert_idx, ring_id, comm_dep_vars):
    """
    insert sync_comm_op for single var
    """
    op_role = get_valid_op_role(block, insert_idx)
    block._insert_op_without_sync(
        insert_idx,
        type='c_sync_comm_stream',
        inputs={'X': comm_dep_vars},
        outputs={'Out': comm_dep_vars},
        attrs={'ring_id': ring_id, OP_ROLE_KEY: op_role},
    )
    return 1


def insert_sync_comm_ops(block, insert_idx, ring_id, comm_dep_vars):
    """
    insert sync_comm_op for vars
    """
    # NOTE (JZ-LIANG) to be checked, may result in an undefined case
    if len(comm_dep_vars) == 0:
        return 0

    op_role = get_valid_op_role(block, insert_idx)
    block._insert_op_without_sync(
        insert_idx,
        type='c_sync_comm_stream',
        inputs={'X': comm_dep_vars},
        outputs={'Out': comm_dep_vars},
        attrs={'ring_id': int(ring_id), OP_ROLE_KEY: op_role},
    )
    return 1


def insert_fill_constant_ops(block, insert_idx, fill_constant_vars):
    """
    _add_fill_constant_ops
    """
    op_role = get_valid_op_role(block, insert_idx)
    for broadcast_name in fill_constant_vars:
        broadcast_var = block.var(broadcast_name)
        block._insert_op_without_sync(
            insert_idx,
            type="fill_constant",
            outputs={"Out": broadcast_var.name},
            attrs={
                "shape": broadcast_var.shape,
                "dtype": broadcast_var.dtype,
                "value": 0.0,
                OP_ROLE_KEY: op_role,
            },
        )
    return


def insert_cast_ops(block, insert_idx, cast_ops):
    """
    _add_cast_ops
    """
    op_role = get_valid_op_role(block, insert_idx)
    for fp16_name, fp32_name in cast_ops.items():
        block._insert_op_without_sync(
            insert_idx,
            type="cast",
            inputs={"X": fp32_name},
            outputs={"Out": fp16_name},
            attrs={
                "in_dtype": core.VarDesc.VarType.FP32,
                "out_dtype": core.VarDesc.VarType.FP16,
                OP_ROLE_KEY: op_role,
            },
        )
    return


def insert_allreduce_ops(
    block,
    insert_idx,
    ring_id,
    allreduce_vars,
    op_role=OpRole.Backward,
    use_calc_stream=False,
    user_defined_strategy=None,
):
    """
    _add_allreduce_ops
    """
    if len(allreduce_vars) == 0:
        return

    if (
        user_defined_strategy
        and user_defined_strategy.fuse_all_reduce_ops
        and not user_defined_strategy.fuse_grad_merge
    ):
        # If fuse_grad_merge is enabled, the grad vars have already been fused during
        # the gradient merge pass; therefore, those vars do not need to be fused here
        insert_fused_allreduce_ops(
            block,
            insert_idx,
            ring_id,
            allreduce_vars,
            op_role,
            use_calc_stream,
            user_defined_strategy.fuse_grad_size_in_MB,
        )
    else:
        for var in allreduce_vars:
            block._insert_op_without_sync(
                insert_idx,
                type='c_allreduce_sum',
                inputs={'X': var},
                outputs={'Out': var},
                attrs={
                    'ring_id': ring_id,
                    'use_calc_stream': use_calc_stream,
                    OP_ROLE_KEY: op_role,
                },
            )

    return
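

# Usage sketch (hedged): how the insertion helpers above are typically chained.
# The block, ring id and var names below are hypothetical and only illustrate
# the intended call pattern; they are not part of this module.
#
#     idx = get_first_optimize_op_idx(main_block)
#     # ops inserted later at the same index end up earlier in the block, so
#     # insert the trailing sync_comm first and the allreduce itself second
#     insert_sync_comm_ops(main_block, idx, dp_ring_id, grad_names)
#     insert_allreduce_ops(main_block, idx, dp_ring_id, grad_names)
#     main_block._sync_with_cpp()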


class FuseHelper(object):
    @staticmethod
    def sort_vars_by_dtype(block, vars_name):
        fp32_vars = []
        fp16_vars = []
        other_vars = []
        for var in vars_name:
            dtype = block.var(var).dtype
            if dtype == paddle.float32:
                fp32_vars.append(var)
            elif dtype == paddle.float16:
                fp16_vars.append(var)
            else:
                other_vars.append(var)
        assert len(other_vars) == 0, "only support fp32/fp16 vars for fuse"

        fp32_vars.extend(fp16_vars)
        return fp32_vars

    @staticmethod
    def get_fused_groups(block, vars_name, fuse_size=32.0):
        """coalesce tensor, get fused group"""
        groups = []
        cur_size = 0.0
        last_dtype = None
        for var_name in vars_name:
            real_var = block.var(var_name)
            var_size = get_var_size(real_var)
            if (
                cur_size + var_size > fuse_size
                or len(groups) == 0
                or real_var.dtype != last_dtype
            ):
                groups.append([real_var])
                cur_size = var_size
                last_dtype = real_var.dtype
            else:
                groups[-1].append(real_var)
                cur_size += var_size
        return groups

    @staticmethod
    def insert_coalesce_tensor(
        block, index, groups, op_role=OpRole.Backward, prefix="Output"
    ):
        fused_vars = []
        insert_num = 0
        for group in groups:
            assert len(group) >= 1
            if len(group) == 1:
                # no need to fuse
                fused_vars.append(group[0])
                continue

            fused_var = block.create_var(
                name=unique_name.generate(
                    'Fused{}_{}'.format(prefix, group[0].name)
                ),
                dtype=group[0].dtype,
                persistable=False,
                stop_gradient=True,
            )
            fused_vars.append(fused_var)
            block._insert_op_without_sync(
                index,
                type="coalesce_tensor",
                inputs={"Input": group},
                outputs={"Output": group, "FusedOutput": fused_var},
                attrs={
                    "copy_data": True,
                    "use_align": True,
                    "dtype": group[0].dtype,
                    OP_ROLE_KEY: op_role,
                },
            )
            insert_num += 1
        return fused_vars, insert_num
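

# Grouping example for FuseHelper.get_fused_groups (hedged, with made-up sizes):
# with fuse_size=32 (MB), three fp32 vars of 10 MB each fall into one group
# [[v1, v2, v3]] because the running size stays <= 32 MB, whereas three vars of
# 20 MB each produce [[v1], [v2], [v3]] since adding a second var would exceed
# the limit. A dtype change (e.g. fp32 -> fp16) always starts a new group.
#
#     groups = FuseHelper.get_fused_groups(block, grad_names, fuse_size=32)
#     fused, num = FuseHelper.insert_coalesce_tensor(
#         block, insert_idx, groups, prefix="Grad"
#     )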


def insert_fused_allreduce_ops(
    block,
    insert_idx,
    ring_id,
    allreduce_vars,
    op_role=OpRole.Backward,
    use_calc_stream=False,
    fuse_grad_size_in_MB=32,
):
    groups = FuseHelper.get_fused_groups(
        block, allreduce_vars, fuse_grad_size_in_MB
    )

    fused_vars, insert_num = FuseHelper.insert_coalesce_tensor(
        block, insert_idx, groups, op_role, prefix="Grad"
    )

    for fused_var in fused_vars:
        block._insert_op_without_sync(
            insert_idx + insert_num,
            type='c_allreduce_sum',
            inputs={'X': fused_var},
            outputs={'Out': fused_var},
            attrs={
                'ring_id': ring_id,
                'use_calc_stream': use_calc_stream,
                OP_ROLE_KEY: op_role,
            },
        )
        if not use_calc_stream:
            block._insert_op_without_sync(
                insert_idx + insert_num,
                type='c_sync_calc_stream',
                inputs={'X': fused_var},
                outputs={'Out': fused_var},
                attrs={OP_ROLE_KEY: op_role},
            )


def insert_fused_reduce_ops(
    block,
    insert_idx,
    ring_id,
    reduce_vars,
    shard,
    op_role=OpRole.Backward,
    use_calc_stream=False,
    rank=None,
    fuse_grad_size=32,
):
    nranks = shard.worker_num
    device_to_vars = [[] for _ in range(nranks)]

    for var in reduce_vars:
        root_id = get_grad_device(var, shard)
        assert 0 <= root_id < nranks, (
            "root_id should >=0 and < nranks, "
            "but now nranks={}, the root_id of var={} is {}".format(
                nranks, var, root_id
            )
        )
        device_to_vars[root_id].append(var)

    for root_id, vars_name in enumerate(device_to_vars):
        groups = FuseHelper.get_fused_groups(block, vars_name, fuse_grad_size)

        fused_vars, insert_num = FuseHelper.insert_coalesce_tensor(
            block, insert_idx, groups, op_role, prefix="Grad"
        )

        for fused_var in fused_vars:
            block._insert_op_without_sync(
                insert_idx + insert_num,
                type='c_reduce_sum',
                inputs={'X': fused_var},
                outputs={'Out': fused_var},
                attrs={
                    'ring_id': ring_id,
                    'root_id': root_id,
                    'use_calc_stream': use_calc_stream,
                    OP_ROLE_KEY: op_role,
                },
            )
            if not use_calc_stream:
                block._insert_op_without_sync(
                    insert_idx + insert_num,
                    type='c_sync_calc_stream',
                    inputs={'X': fused_var},
                    outputs={'Out': fused_var},
                    attrs={OP_ROLE_KEY: op_role},
                )

    return [] if rank is None else device_to_vars[rank]


def insert_reduce_ops(
    block,
    insert_idx,
    ring_id,
    reduce_vars,
    shard,
    op_role=OpRole.Backward,
    use_calc_stream=False,
    rank=None,
    strategy=None,
):
    """
    _add_reduce_ops
    """
    if (
        strategy
        and strategy.fuse_all_reduce_ops
        and not strategy.fuse_grad_merge
    ):
        return insert_fused_reduce_ops(
            block,
            insert_idx,
            ring_id,
            reduce_vars,
            shard,
            op_role,
            use_calc_stream,
            rank,
            strategy.fuse_grad_size_in_MB,
        )

    grad_in_this_device = []
    for var in reduce_vars:
        grad_var = var
        if (
            strategy
            and strategy.fuse_all_reduce_ops
            and strategy.fuse_grad_merge
        ):
            # TODO(wangxi): if fp16_allreduce is supported, this needs to be
            # 'FusedMergedGrad.cast_fp16._'
            grad_var = var.replace('FusedMergedGrad_', '')
        root_id = get_grad_device(grad_var, shard)
        assert (
            root_id >= 0
        ), "root id should be a positive int, but now root id is {}".format(
            root_id
        )
        if rank is not None and rank == root_id:
            grad_in_this_device.append(var)
        block._insert_op_without_sync(
            insert_idx,
            type='c_reduce_sum',
            inputs={'X': var},
            outputs={'Out': var},
            attrs={
                'ring_id': ring_id,
                'root_id': root_id,
                'use_calc_stream': use_calc_stream,
                OP_ROLE_KEY: op_role,
            },
        )

    return grad_in_this_device


def insert_fused_broadcast_param_ops(
    block,
    insert_idx,
    ring_id,
    params,
    shard,
    op_role=OpRole.Optimize,
    use_calc_stream=False,
    rank=None,
    fuse_size=32,
):
    nranks = shard.worker_num
    device_to_vars = [[] for _ in range(nranks)]

    for var in params:
        root_id = shard.device(var)
        assert 0 <= root_id < nranks, (
            "root_id should >=0 and < nranks, "
            "but now nranks={}, the root_id of var={} is {}".format(
                nranks, var, root_id
            )
        )
        device_to_vars[root_id].append(var)

    for root_id, vars_name in enumerate(device_to_vars):
        groups = FuseHelper.get_fused_groups(block, vars_name, fuse_size)

        fused_vars, insert_num = FuseHelper.insert_coalesce_tensor(
            block, insert_idx, groups, op_role, prefix="Param"
        )

        for fused_var in fused_vars:
            block._insert_op_without_sync(
                insert_idx + insert_num,
                type='c_broadcast',
                inputs={'X': fused_var},
                outputs={'Out': fused_var},
                attrs={
                    'ring_id': ring_id,
                    'root': root_id,
                    'use_calc_stream': use_calc_stream,
                    OP_ROLE_KEY: op_role,
                },
            )
            if not use_calc_stream:
                block._insert_op_without_sync(
                    insert_idx + insert_num,
                    type='c_sync_calc_stream',
                    inputs={'X': fused_var},
                    outputs={'Out': fused_var},
                    attrs={OP_ROLE_KEY: op_role},
                )

    return [] if rank is None else device_to_vars[rank]


def insert_broadcast_param_ops(
    block,
    insert_idx,
    ring_id,
    params,
    shard,
    op_role=OpRole.Optimize,
    use_calc_stream=False,
    rank=None,
    strategy=None,
):
    """
    add broadcast param ops
    """
    if strategy and strategy.fuse_all_reduce_ops:
        # TODO(wangxi): put fused var in startup_program, only need exec once
        return insert_fused_broadcast_param_ops(
            block,
            insert_idx,
            ring_id,
            params,
            shard,
            op_role,
            use_calc_stream,
            rank,
            strategy.fuse_grad_size_in_MB,
        )

    param_in_this_device = []
    for param in params:
        root_id = shard.device(param)
        assert (
            root_id >= 0
        ), "root id should be a positive int, but now root id is {}".format(
            root_id
        )
        if rank is not None and rank == root_id:
            param_in_this_device.append(param)
        block._insert_op_without_sync(
            insert_idx,
            type='c_broadcast',
            inputs={'X': param},
            outputs={'Out': param},
            attrs={
                'ring_id': ring_id,
                'root': root_id,
                'use_calc_stream': use_calc_stream,
                OP_ROLE_KEY: op_role,
            },
        )

    return param_in_this_device


def fuse_opt_broadcast_param_ops(
    block, ring_id, shard, op_role=OpRole.Optimize, strategy=None
):
    """
    fuse optimizer sharding broadcast param ops
    """
    if strategy is None or not strategy.fuse_all_reduce_ops:
        return

    fuse_size = strategy.fuse_grad_size_in_MB

    nranks = shard.worker_num
    device_to_vars = [[] for _ in range(nranks)]

    for idx, op in reversed(list(enumerate(block.ops))):
        if not is_optimizer_op(op) or op.type != 'c_broadcast':
            break
        var = op.input_arg_names[0]
        root_id = op.attr('root')
        device_to_vars[root_id].insert(0, var)
        block._remove_op(idx, sync=False)

    insert_idx = idx + 1
    for root_id, vars_name in enumerate(device_to_vars):
        vars_name = FuseHelper.sort_vars_by_dtype(block, vars_name)
        groups = FuseHelper.get_fused_groups(block, vars_name, fuse_size)

        fused_vars, insert_num = FuseHelper.insert_coalesce_tensor(
            block, insert_idx, groups, op_role, prefix="Param"
        )

        for fused_var in fused_vars:
            block._insert_op_without_sync(
                insert_idx + insert_num,
                type='c_broadcast',
                inputs={'X': fused_var},
                outputs={'Out': fused_var},
                attrs={
                    'ring_id': ring_id,
                    'root': root_id,
                    'use_calc_stream': True,
                    OP_ROLE_KEY: op_role,
                },
            )

    block._sync_with_cpp()


def get_grad_device(grad_name, shard):
    assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format(
        grad_name
    )
    base_name = None
    # NOTE: mind the traversal order
    possible_suffixes = [
        # sharding gm
        '.cast_fp16@GRAD@MERGED',
        '.cast_fp16@GRAD',
        # pipeline
        '@GRAD@MERGED@FP16',
        '@GRAD@MERGED',
        '@GRAD',
    ]
    for suffix in possible_suffixes:
        if suffix in grad_name:
            base_name = re.sub(suffix, '', grad_name)
            break

    assert (
        base_name in shard.global_param2device
    ), "[{}] should be a param variable.".format(base_name)

    return shard.global_param2device[base_name]
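

# Example (hedged, with hypothetical var names): get_grad_device strips the
# first matching suffix to recover the parameter name, then looks it up in the
# shard:
#
#     "fc_0.w_0@GRAD@MERGED"      -> base_name "fc_0.w_0"
#     "fc_0.w_0.cast_fp16@GRAD"   -> base_name "fc_0.w_0"
#     shard.global_param2device["fc_0.w_0"] then gives the owning rank.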


def get_first_check_finite_and_unscale_op_idx(block, raise_error=True):

    for idx, op in enumerate(block.ops):
        if op.type == "check_finite_and_unscale":
            return idx

    if raise_error:
        raise ValueError(
            "amp is turned on but check_finite_and_unscale op does not exist in main block"
        )

    return -1


def get_first_optimize_op_idx(block):
    first_opt_op_idx = None
    for index, op in reversed(tuple(enumerate(block.ops))):
        if is_backward_op(op) and first_opt_op_idx is None:
            first_opt_op_idx = index + 1
            break
    return first_opt_op_idx


def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root):
    """
    _add_broadcast_ops
    """
    op_role = get_valid_op_role(block, insert_idx)
    for broadcast_name, root_device in broadcast2root:
        block._insert_op_without_sync(
            insert_idx,
            type='c_broadcast',
            inputs={'X': broadcast_name},
            outputs={'Out': broadcast_name},
            attrs={
                'ring_id': ring_id,
                'root': root_device,
                OP_ROLE_KEY: op_role,
            },
        )

    return


DtypeToSize = {
    core.VarDesc.VarType.FP16: 2,
    core.VarDesc.VarType.FP32: 4,
    core.VarDesc.VarType.FP64: 8,
    core.VarDesc.VarType.INT16: 2,
    core.VarDesc.VarType.INT32: 4,
    core.VarDesc.VarType.INT64: 8,
    core.VarDesc.VarType.BOOL: 1,
    core.VarDesc.VarType.UINT8: 1,
}


def get_var_size(param):
    """
    input:
        - param: var
    return:
        var size in MB
    """
    assert -1 not in param.shape
    return (
        reduce(lambda x, y: x * y, param.shape)
        * DtypeToSize[param.dtype]
        / 1024.0
        / 1024.0
    )
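

# Worked example for get_var_size (illustrative only): an FP32 var with shape
# [1024, 1024] takes 1024 * 1024 elements * 4 bytes = 4194304 bytes, i.e.
# 4194304 / 1024.0 / 1024.0 = 4.0 MB, so get_var_size returns 4.0.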


def insert_scale_loss_grad_ops(block, scale=1.0):
    '''
    In order to keep the learning rate consistent in different numbers of
    training workers, we scale the loss grad by the number of workers
    '''
    for idx, op in reversed(list(enumerate(block.ops))):
        if is_loss_grad_op(op):
            assert op.type == 'fill_constant', (
                "loss_grad_op must be fill_constant op, "
                "but this op is {}".format(op.type)
            )
            assert op.has_attr('value')
            loss_scale = float(op.attr('value'))
            loss_scale = loss_scale / scale
            op._set_attr('value', loss_scale)
            break
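

# Example (hedged): with scale=8 (e.g. 8 training workers), a loss_grad
# fill_constant whose 'value' attr is 1.0 is rewritten to 1.0 / 8 = 0.125,
# which keeps the effective learning rate consistent across worker counts.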


def comm_analyse(main_program):
    """
    Analyse the parameter sizes that need to be broadcast/allreduced during sharding training
    """
    reduce_vars = {}
    broadcast_vars = {}
    block = main_program.global_block()
    for op in block.ops:
        if op.type == "c_broadcast":
            var_name = op.desc.input_arg_names()[0]
            # convert MB to KB
            broadcast_vars[var_name] = (
                get_var_size(block.var(var_name)) * 1024.0
            )
        elif op.type == "c_allreduce_sum":
            var_name = op.desc.input_arg_names()[0]
            reduce_vars[var_name] = get_var_size(block.var(var_name)) * 1024.0

    varsize_count = {}
    gap = 1

    for k, v in broadcast_vars.items():
        print("broadcast: {}: {} KB".format(k, v))
        if int(v / gap) in varsize_count:
            varsize_count[int(v / gap)] += 1
        else:
            varsize_count[int(v / gap)] = 1

    for k, v in reduce_vars.items():
        print("allreduce: {}: {} KB".format(k, v))
        if int(v / gap) in varsize_count:
            varsize_count[int(v / gap)] += 1
        else:
            varsize_count[int(v / gap)] = 1

    with open("nccl_size.txt", 'w') as f:
        sorted_varsize = sorted(varsize_count.items(), key=lambda x: x[0])
        for varsize, count in sorted_varsize:
            print("NCCL size {}~{} KB: {}".format(varsize, varsize + 1, count))
            f.write(
                "NCCL size {}~{} KB: {}\n".format(varsize, varsize + 1, count)
            )
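

# Usage sketch (hedged; the program variable is hypothetical): comm_analyse
# prints the size in KB of every broadcast/allreduce var and writes a histogram
# with 1 KB buckets to "nccl_size.txt" in the working directory:
#
#     comm_analyse(sharded_main_program)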


def add_sync_comm(program, sharding_ring_id):
    """
    When cloning a test prog from the sharding main prog,
    part of the sync_comm ops may be pruned by mistake; this function
    adds the sync_comm ops back for the test prog.

    """
    # NOTE (liangjianzhong): only one comm stream is supported for now; using more
    # than one comm stream will cause errors. This should be revised in the future.

    assert sharding_ring_id >= 0, "sharding_ring_id should be no less than zero"
    block = program.global_block()
    not_sync_vars = set([])
    for op in block.ops:
        if op.type in ["c_broadcast", "c_allreduce"]:
            for input_name in op.desc.input_arg_names():
                not_sync_vars.add(input_name)
        if op.type == "c_sync_comm_stream":
            for input_name in op.desc.input_arg_names():
                not_sync_vars.remove(input_name)
    if not_sync_vars:
        block.append_op(
            type='c_sync_comm_stream',
            inputs={'X': list(not_sync_vars)},
            outputs={'Out': list(not_sync_vars)},
            attrs={
                'ring_id': sharding_ring_id,
                'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
            },
        )
    return
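

# Usage sketch (hedged; program/ring names are hypothetical): after cloning an
# eval program from the sharding train program, re-attach the possibly pruned
# sync_comm ops:
#
#     test_prog = train_prog.clone(for_test=True)
#     add_sync_comm(test_prog, sharding_ring_id)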


def save_persistables(exe, dirname, main_program, filename=None):
    """
    When use sharding, part of persistable vars are unique and are partitioned in different ranks,
    and part of persistable vars are duplicated and exist in all the ranks with different values.
    This function handles the model saving for sharding training.
    """
    # TODO (JZ-LIANG) revise this for uniform mixed parallelism
    if main_program._pipeline_opt:
        main_program = main_program._pipeline_opt['section_program']

    def is_opt_vars(var):
        # NOTE(JZ-LIANG): The checks should be updated when adding a new compatible optimizer
        # now only Momentum and Adam are compatible with sharding,
        # support EMA optimizer with '_ema_0',
        # support offload with '@offload_0' and '.cast_fp16'
        checks = [
            "_moment1_0",
            "_moment2_0",
            "_beta1_pow_acc_0",
            "_beta2_pow_acc_0",
            "_velocity_0",
            "_ema_0",
            "@offload_0",
            ".cast_fp16",
        ]
        for check in checks:
            if var.name.endswith(check) and var.persistable:
                return True
        return False

    def is_gradient_merge_vars(var):
        # NOTE(JZ-LIANG): to revise save/load logic in framework instead of writing this naive rule

        return var.name.endswith("@GradiantMerge")

    def is_trainable(var):
        return (
            isinstance(var, paddle.fluid.framework.Parameter) and var.trainable
        )

    def sharding_predicate(var):
        return (
            is_trainable(var) or is_opt_vars(var) or is_gradient_merge_vars(var)
        )

    if int(os.environ.get('PADDLE_TRAINER_ID', 0)) == 0:
        paddle.fluid.io.save_persistables(
            exe, dirname, main_program=main_program, filename=None
        )
    else:
        paddle.fluid.io.save_vars(
            exe,
            dirname,
            main_program=main_program,
            predicate=sharding_predicate,
            filename=None,
        )

    return
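

# Usage sketch (hedged; the path is hypothetical): every rank calls this once.
# Rank 0 saves the full set of persistables, while the other ranks save only
# the vars selected by sharding_predicate (their own shard of trainable params,
# optimizer states and gradient-merge vars):
#
#     save_persistables(exe, "./sharding_ckpt", main_program=train_prog)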


def append_naive_sync(block, sync_var, ring_id):
    # NOTE (JZ-LIANG) update this to use barrier sync for more elegant logic
    # sync within global
    block.append_op(
        type="fill_constant",
        outputs={"Out": sync_var},
        attrs={
            "shape": sync_var.shape,
            "dtype": sync_var.dtype,
            "value": int(1),
        },
    )
    block.append_op(
        type='c_allreduce_sum',
        inputs={'X': sync_var},
        outputs={'Out': sync_var},
        attrs={
            'ring_id': ring_id,
            'use_calc_stream': True,
            OP_ROLE_KEY: OpRole.Forward,
        },
    )
    block.append_op(
        type='c_sync_calc_stream',
        inputs={'X': [sync_var]},
        outputs={'Out': [sync_var]},
        attrs={OP_ROLE_KEY: OpRole.Forward},
    )
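

# Sketch of what append_naive_sync is used for (hedged; the var creation below
# is hypothetical and not part of this module): a tiny all-reduced tensor acts
# as a naive barrier, every rank fills it with 1, allreduces it on ring_id,
# then waits on the calc stream:
#
#     sync_var = block.create_var(
#         name=unique_name.generate("naive_sync"),
#         shape=[1],
#         dtype=core.VarDesc.VarType.INT32,
#     )
#     append_naive_sync(block, sync_var, global_ring_id)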