utils.py 38.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
J
JZ-LIANG 已提交
14
import paddle
15
from paddle.fluid import core, unique_name
16
from functools import reduce
17
from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op, is_backward_op, is_optimizer_op
18
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
19 20

import re
J
JZ-LIANG 已提交
21
import os
22 23 24 25 26 27 28 29 30


def check_broadcast(block):
    """
    Validate the ordering of asynchronous ``c_broadcast`` ops in ``block``.

    Only broadcasts issued on the communication stream (``use_calc_stream``
    is False) whose input var name contains ``@BroadCast`` are checked; any
    other broadcast (e.g. one belonging to an inner parallelism such as
    Megatron) is skipped and treated like a regular op.

    Rules enforced:
      * a var may be broadcast at most once;
      * if the broadcast var is produced by a ``fill_constant`` op, that op
        must come before a ``c_sync_calc_stream`` op, which itself must come
        before the broadcast;
      * every op that reads a broadcast var must come after a
        ``c_sync_comm_stream`` op that itself comes after the broadcast.

    Args:
        block: program block whose ``ops`` are inspected (not modified).

    Raises:
        ValueError: if the same var is broadcast twice.
        AssertionError: if one of the ordering rules is violated.
    """

    def _is_checked_broadcast(op):
        # async c_broadcast of a sharding "@BroadCast" var
        return (op.type == "c_broadcast" and
                not op.all_attrs()["use_calc_stream"] and
                "@BroadCast" in op.desc.input_arg_names()[0])

    broadcast_vars = {}
    for idx, op in enumerate(block.ops):
        if not _is_checked_broadcast(op):
            continue
        var_name = op.desc.input_arg_names()[0]
        if var_name in broadcast_vars:
            raise ValueError(
                "var_name already exists: {}, "
                "the old pos is {}, the new pos is {}".format(
                    var_name, broadcast_vars[var_name]["broadcast_pos"], idx))
        broadcast_vars[var_name] = {
            "fill_constant_pos": -1,
            "broadcast_pos": idx,
        }

    # Remember the (last) fill_constant op that produces each broadcast var.
    for idx, op in enumerate(block.ops):
        if op.type == "fill_constant":
            var_name = op.desc.output_arg_names()[0]
            if var_name in broadcast_vars:
                broadcast_vars[var_name]["fill_constant_pos"] = idx

    last_sync_comm_op_idx = -1
    last_sync_calc_op_idx = -1
    for idx, op in enumerate(block.ops):
        if op.type == "c_sync_comm_stream":
            last_sync_comm_op_idx = idx
            continue
        if op.type == "c_sync_calc_stream":
            last_sync_calc_op_idx = idx
            continue
        if _is_checked_broadcast(op):
            var_name = op.desc.input_arg_names()[0]
            fill_pos = broadcast_vars[var_name]["fill_constant_pos"]
            if fill_pos != -1:
                assert last_sync_calc_op_idx != -1, (
                    "no c_sync_calc_stream before broadcast of {}".format(
                        var_name))
                assert fill_pos < last_sync_calc_op_idx, (
                    "fill_constant of {} must precede the c_sync_calc_stream "
                    "before its broadcast".format(var_name))
                assert last_sync_calc_op_idx < idx
            continue
        # Any other op (including non-checked broadcasts) reading a
        # broadcast var must be separated from the broadcast by a sync_comm.
        for input_name in op.desc.input_arg_names():
            if input_name in broadcast_vars:
                broadcast_pos = broadcast_vars[input_name]["broadcast_pos"]
                assert broadcast_pos != -1
                assert broadcast_pos < last_sync_comm_op_idx, (
                    "{} is used without a c_sync_comm_stream after its "
                    "broadcast".format(input_name))
                assert last_sync_comm_op_idx < idx
    return


def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1):
    """
    Check that the gradient-sync collectives in ``block`` are correctly ordered.

    Every var consumed by a ``c_allreduce_sum`` / ``c_reduce_sum`` op issued
    on the communication stream (``use_calc_stream`` is False) is tracked.
    The op order should be:

        grad:
            - 0: op that generates Var
            - 1: sync_calc
            - 2: reduce_sum_sharding (allreduce --> reduce)
            - 3: sync_comm
            - 4: allreduce_sum_dp (dp_grads)
            - 5: sync_comm (dp_grads)
            - 6: op that uses Var (dp_grads & sum)

    Collectives issued on the calc stream are skipped, so allreduce ops of
    an inner parallelism (e.g. Megatron) are ignored.

    Two status maps hold the stage number each var has reached (-1 means
    "registered but not produced yet"):
      * ``dp_grads_status``: ``@GRAD`` vars whose param ``shard`` holds;
      * ``vars_status``: vars whose name contains ``sum`` and grads of
        params ``shard`` does not hold. A later read of such a var by a
        regular op raises ValueError.

    Args:
        block: program block whose ``ops`` are inspected (not modified).
        shard: sharding helper; only ``shard.has_param(name)`` is used.
        sharding_ring_id: ring id of the sharding communication group.
        dp_ring_id: ring id of the data-parallel group; -1 when there is no
            data-parallel allreduce on top of sharding.

    Raises:
        ValueError: a var is reduced before it is produced or before a
            sync_calc, is read before the required sync_comm(s), or (for
            ``vars_status`` vars) is read at all by a regular op.
        AssertionError: other ordering or ring-id violations.
    """
    vars_status = {}
    dp_grads_status = {}
    idx_last_grad_allreduce = -1
    idx_amp_allreduce = -1
    idx_gradient_clip_allreduce = -1

    # Pass 1: register every var synced on the communication stream.
    for idx, op in enumerate(block.ops):
        # sharding uses both allreduce and reduce to sync grads
        if op.type in ("c_allreduce_sum", "c_reduce_sum"):
            if not op.all_attrs()["use_calc_stream"]:
                ring_id = op.desc.attr("ring_id")
                var_name = op.desc.input_arg_names()[0]
                param = var_name.split("@")[0]

                assert 'sum' in var_name or ("@GRAD" in var_name)
                if 'sum' in var_name or (not shard.has_param(param)):
                    vars_status[var_name] = -1
                else:
                    dp_grads_status[var_name] = -1

                if ring_id != sharding_ring_id:
                    assert shard.has_param(param)
                    assert ring_id == dp_ring_id

                if "sum" in var_name:
                    idx_amp_allreduce = idx
                elif "@GRAD" in var_name:
                    idx_last_grad_allreduce = idx

        if op.type == "c_allreduce_max":
            idx_gradient_clip_allreduce = idx

    # Pass 2: walk the program and advance each var through the stages.
    for op in block.ops:
        if op.type == "c_sync_calc_stream":
            for status_map in (vars_status, dp_grads_status):
                for var_name, status in status_map.items():
                    if status == 0:
                        status_map[var_name] = 1

        # check sharding allreduce and reduce but skip megatron allreduce
        elif op.type in ("c_allreduce_sum", "c_reduce_sum"):
            if not op.all_attrs()["use_calc_stream"]:
                var_name = op.desc.input_arg_names()[0]
                ring_id = op.desc.attr("ring_id")
                if ring_id == sharding_ring_id:
                    assert op.type == "c_reduce_sum", "Grad in Sharding group should be reduce rather than allreduce"
                    status_map = (vars_status if var_name in vars_status else
                                  dp_grads_status)
                    _status = status_map[var_name]
                    if _status == -1:
                        raise ValueError(
                            "{} is not generated, but you are "
                            "trying to all-reduce it".format(var_name))
                    if _status == 0:
                        raise ValueError("There should be a sync_calc op "
                                         "after generate Var: {} and before the "
                                         "c_allreduce_sum op".format(var_name))
                    assert (_status == 1)
                    status_map[var_name] = 2
                else:
                    assert ring_id == dp_ring_id
                    param = var_name.split("@")[0]
                    assert shard.has_param(param)
                    assert dp_grads_status[var_name] == 3
                    dp_grads_status[var_name] = 4

        elif op.type == "c_sync_comm_stream":
            ring_id = op.desc.attr("ring_id")
            if ring_id == sharding_ring_id:
                for var_name in op.desc.input_arg_names():
                    if var_name in vars_status:
                        assert vars_status[var_name] == 2
                        vars_status[var_name] = 3
                    elif var_name in dp_grads_status:
                        assert dp_grads_status[var_name] == 2
                        dp_grads_status[var_name] = 3
            else:
                for var_name in op.desc.input_arg_names():
                    param = var_name.split("@")[0]
                    assert ring_id == dp_ring_id
                    assert shard.has_param(param)
                    assert dp_grads_status[var_name] == 4
                    dp_grads_status[var_name] = 5

        else:
            for input_name in op.desc.input_arg_names():
                if input_name in vars_status:
                    if vars_status[input_name] != 3:
                        raise ValueError(
                            "There should be a sync_comm op "
                            "after allreduce the Var: {}".format(input_name))
                    raise ValueError(
                        "The reduce output grad [{}] should NOT be used in "
                        "Non-root rank.".format(input_name))
                if input_name in dp_grads_status:
                    if dp_ring_id == -1:
                        if dp_grads_status[input_name] != 3:
                            raise ValueError(
                                "There should be a sync_comm op "
                                "after allreduce the Var: {}".format(
                                    input_name))
                    elif dp_grads_status[input_name] != 5:
                        raise ValueError(
                            "The grad in shard should be allreduce and sync "
                            "twice before usage {}".format(input_name))

            # A regular op writing a registered var marks it as generated.
            for output_name in op.desc.output_arg_names():
                if vars_status.get(output_name) == -1:
                    vars_status[output_name] = 0
                if dp_grads_status.get(output_name) == -1:
                    dp_grads_status[output_name] = 0

    # check sharding with amp: the "sum" allreduce follows every grad sync
    if idx_amp_allreduce != -1:
        assert idx_amp_allreduce > idx_last_grad_allreduce

    # check sharding with gradient_clip_by_global_norm
    if idx_gradient_clip_allreduce != -1:
        assert idx_gradient_clip_allreduce > idx_last_grad_allreduce

    return


def get_valid_op_role(block, insert_idx):
    """
    Return the op role (OpRole.Forward or OpRole.Backward) for an op inserted
    at ``insert_idx`` of ``block``.

    Ops are scanned forward starting at ``insert_idx``; the first op whose
    ``op_role`` is exactly Backward or Optimize yields OpRole.Backward, and
    the first one whose role is exactly Forward or Loss yields
    OpRole.Forward. Ops with any other role value are skipped. If the scan
    reaches the end of the block (including ``insert_idx == len(block.ops)``,
    i.e. appending), OpRole.Backward is returned.
    """
    backward_roles = (int(OpRole.Backward), int(OpRole.Optimize))
    forward_roles = (int(OpRole.Forward), int(OpRole.Loss))

    ops = block.ops
    idx = insert_idx
    # Bounds are checked before indexing so inserting at the end is valid;
    # iterating instead of recursing avoids deep recursion on long runs of
    # ops with other roles.
    while idx < len(ops):
        op_role = ops[idx].attr('op_role')
        if op_role in backward_roles:
            return OpRole.Backward
        if op_role in forward_roles:
            return OpRole.Forward
        idx += 1
    return OpRole.Backward


def insert_sync_calc_op(block, insert_idx, calc_dep_vars):
    """
    Insert a ``c_sync_calc_stream`` op at position ``insert_idx`` of ``block``.

    ``calc_dep_vars`` is used as both the ``X`` input and the ``Out`` output.
    The op role is chosen by ``get_valid_op_role`` from the ops that follow
    the insertion point.
    """
    role = get_valid_op_role(block, insert_idx)
    sync_attrs = {OP_ROLE_KEY: role}
    block._insert_op_without_sync(insert_idx,
                                  type='c_sync_calc_stream',
                                  inputs={'X': calc_dep_vars},
                                  outputs={'Out': calc_dep_vars},
                                  attrs=sync_attrs)


def insert_sync_comm_op(block, insert_idx, ring_id, comm_dep_vars):
    """
    Insert one ``c_sync_comm_stream`` op for ``comm_dep_vars``.

    The op goes at ``insert_idx`` on ring ``ring_id`` (passed through
    unchanged) and uses ``comm_dep_vars`` as both input ``X`` and output
    ``Out``.

    Returns:
        int: 1, the number of ops inserted.
    """
    sync_attrs = {
        'ring_id': ring_id,
        OP_ROLE_KEY: get_valid_op_role(block, insert_idx),
    }
    block._insert_op_without_sync(insert_idx,
                                  type='c_sync_comm_stream',
                                  inputs={'X': comm_dep_vars},
                                  outputs={'Out': comm_dep_vars},
                                  attrs=sync_attrs)
    return 1


def insert_sync_comm_ops(block, insert_idx, ring_id, comm_dep_vars):
    """
    Insert a single ``c_sync_comm_stream`` op covering all ``comm_dep_vars``.

    Nothing is inserted when ``comm_dep_vars`` is empty. ``ring_id`` is
    converted with ``int`` before being stored as the op attribute.

    Returns:
        int: the number of ops inserted (0 or 1).
    """
    # NOTE (JZ-LIANG) to be check, may result undefined case
    if not comm_dep_vars:
        return 0

    role = get_valid_op_role(block, insert_idx)
    sync_attrs = {'ring_id': int(ring_id), OP_ROLE_KEY: role}
    block._insert_op_without_sync(insert_idx,
                                  type='c_sync_comm_stream',
                                  inputs={'X': comm_dep_vars},
                                  outputs={'Out': comm_dep_vars},
                                  attrs=sync_attrs)
    return 1


def insert_fill_constant_ops(block, insert_idx, fill_constant_vars):
    """
    Insert a zero-valued ``fill_constant`` op for each var name given.

    Each op takes the shape and dtype of ``block.var(name)`` and writes 0.0
    into that var. Every op is inserted at the same ``insert_idx``, so the
    resulting ops appear in reverse order of ``fill_constant_vars``.
    """
    role = get_valid_op_role(block, insert_idx)
    for name in fill_constant_vars:
        target = block.var(name)
        fill_attrs = {
            "shape": target.shape,
            "dtype": target.dtype,
            "value": 0.0,
            OP_ROLE_KEY: role,
        }
        block._insert_op_without_sync(insert_idx,
                                      type="fill_constant",
                                      outputs={"Out": target.name},
                                      attrs=fill_attrs)


def insert_cast_ops(block, insert_idx, cast_ops):
    """
    Insert one FP32 -> FP16 ``cast`` op per entry of ``cast_ops``.

    Args:
        block: block to insert into.
        insert_idx: position at which every cast op is inserted (so the ops
            end up in reverse iteration order of ``cast_ops``).
        cast_ops: mapping from the fp16 output var name to the fp32 input
            var name.
    """
    role = get_valid_op_role(block, insert_idx)
    for dst_name, src_name in cast_ops.items():
        cast_attrs = {
            "in_dtype": core.VarDesc.VarType.FP32,
            "out_dtype": core.VarDesc.VarType.FP16,
            OP_ROLE_KEY: role,
        }
        block._insert_op_without_sync(insert_idx,
                                      type="cast",
                                      inputs={"X": src_name},
                                      outputs={"Out": dst_name},
                                      attrs=cast_attrs)


def insert_allreduce_ops(block,
                         insert_idx,
                         ring_id,
                         allreduce_vars,
                         op_role=OpRole.Backward,
                         use_calc_stream=False,
                         user_defined_strategy=None):
    """
    Insert ``c_allreduce_sum`` ops for ``allreduce_vars`` at ``insert_idx``.

    When ``user_defined_strategy`` enables ``fuse_all_reduce_ops`` without
    ``fuse_grad_merge``, the work is delegated to
    ``insert_fused_allreduce_ops`` (buffers of
    ``fuse_grad_size_in_MB``). Otherwise one in-place allreduce is inserted
    per var, each at ``insert_idx``. Does nothing for an empty var list.
    """
    if not allreduce_vars:
        return

    strategy = user_defined_strategy
    use_fused = (strategy and strategy.fuse_all_reduce_ops and
                 not strategy.fuse_grad_merge)
    if use_fused:
        # With fuse_grad_merge the grads were already fused by the gradient
        # merge pass, which is why that case skips fusion here.
        insert_fused_allreduce_ops(block, insert_idx, ring_id, allreduce_vars,
                                   op_role, use_calc_stream,
                                   strategy.fuse_grad_size_in_MB)
        return

    for var in allreduce_vars:
        allreduce_attrs = {
            'ring_id': ring_id,
            'use_calc_stream': use_calc_stream,
            OP_ROLE_KEY: op_role,
        }
        block._insert_op_without_sync(insert_idx,
                                      type='c_allreduce_sum',
                                      inputs={'X': var},
                                      outputs={'Out': var},
                                      attrs=allreduce_attrs)


class FuseHelper(object):
    """Static helpers that coalesce block vars into fused buffers."""

    @staticmethod
    def sort_vars_by_dtype(block, vars_name):
        """
        Reorder ``vars_name`` so that fp32 vars come before fp16 vars.

        Relative order within each dtype is preserved. Any var that is
        neither fp32 nor fp16 fails an assertion.
        """
        fp32_vars, fp16_vars, unsupported = [], [], []
        for name in vars_name:
            dtype = block.var(name).dtype
            if dtype == paddle.float32:
                bucket = fp32_vars
            elif dtype == paddle.float16:
                bucket = fp16_vars
            else:
                bucket = unsupported
            bucket.append(name)
        assert not unsupported, "only support fp32/fp16 vars for fuse"
        return fp32_vars + fp16_vars

    @staticmethod
    def get_fused_groups(block, vars_name, fuse_size=32.):
        """
        Split ``vars_name`` into consecutive groups of block vars to coalesce.

        A new group is started for the first var, whenever the dtype
        changes, or when adding the var would push the group's accumulated
        ``get_var_size`` above ``fuse_size``. A var that alone exceeds
        ``fuse_size`` therefore ends up in its own group.

        Returns:
            list of lists of block vars (var objects, not names).
        """
        groups = []
        group_size = 0.
        group_dtype = None
        for name in vars_name:
            var = block.var(name)
            size = get_var_size(var)
            start_new_group = (not groups or var.dtype != group_dtype or
                               group_size + size > fuse_size)
            if start_new_group:
                groups.append([var])
                group_size = size
                group_dtype = var.dtype
            else:
                groups[-1].append(var)
                group_size += size
        return groups

    @staticmethod
    def insert_coalesce_tensor(block,
                               index,
                               groups,
                               op_role=OpRole.Backward,
                               prefix="Output"):
        """
        Insert a ``coalesce_tensor`` op at ``index`` for each multi-var group.

        Single-var groups need no fusion and are passed through unchanged.
        For larger groups a non-persistable var named after ``prefix`` and
        the group's first var is created to hold the fused buffer.

        Returns:
            (fused_vars, insert_num): one var per group (the fused buffer or
            the lone var), and the number of coalesce ops inserted.
        """
        fused_vars = []
        insert_num = 0
        for group in groups:
            assert len(group) >= 1
            head = group[0]
            if len(group) > 1:
                fused_var = block.create_var(
                    name=unique_name.generate('Fused{}_{}'.format(
                        prefix, head.name)),
                    dtype=head.dtype,
                    persistable=False,
                    stop_gradient=True)
                fused_vars.append(fused_var)
                coalesce_attrs = {
                    "copy_data": True,
                    "use_align": True,
                    "dtype": head.dtype,
                    OP_ROLE_KEY: op_role,
                }
                block._insert_op_without_sync(index,
                                              type="coalesce_tensor",
                                              inputs={"Input": group},
                                              outputs={
                                                  "Output": group,
                                                  "FusedOutput": fused_var
                                              },
                                              attrs=coalesce_attrs)
                insert_num += 1
            else:
                # no need fuse
                fused_vars.append(head)
        return fused_vars, insert_num


def insert_fused_allreduce_ops(block,
                               insert_idx,
                               ring_id,
                               allreduce_vars,
                               op_role=OpRole.Backward,
                               use_calc_stream=False,
                               fuse_grad_size_in_MB=32):
    """
    Allreduce ``allreduce_vars`` through coalesced buffers.

    The vars are grouped with ``FuseHelper.get_fused_groups`` and coalesced
    at ``insert_idx``. One ``c_allreduce_sum`` per fused var is then inserted
    directly after the coalesce ops. When ``use_calc_stream`` is False, a
    ``c_sync_calc_stream`` is inserted at the same position, which places it
    just before its allreduce.
    """
    fused_vars, coalesce_count = FuseHelper.insert_coalesce_tensor(
        block,
        insert_idx,
        FuseHelper.get_fused_groups(block, allreduce_vars,
                                    fuse_grad_size_in_MB),
        op_role,
        prefix="Grad")

    pos = insert_idx + coalesce_count
    for fused_var in fused_vars:
        block._insert_op_without_sync(pos,
                                      type='c_allreduce_sum',
                                      inputs={'X': fused_var},
                                      outputs={'Out': fused_var},
                                      attrs={
                                          'ring_id': ring_id,
                                          'use_calc_stream': use_calc_stream,
                                          OP_ROLE_KEY: op_role
                                      })
        if use_calc_stream:
            continue
        block._insert_op_without_sync(pos,
                                      type='c_sync_calc_stream',
                                      inputs={'X': fused_var},
                                      outputs={'Out': fused_var},
                                      attrs={OP_ROLE_KEY: op_role})


def insert_fused_reduce_ops(block,
                            insert_idx,
                            ring_id,
                            reduce_vars,
                            shard,
                            op_role=OpRole.Backward,
                            use_calc_stream=False,
                            rank=None,
                            fuse_grad_size=32):
    """
    Reduce ``reduce_vars`` to their owning ranks through coalesced buffers.

    Grads are bucketed by the root rank that ``get_grad_device`` reports.
    Each bucket is coalesced with ``FuseHelper`` using ``fuse_grad_size`` as
    the group limit (callers in this file pass ``fuse_grad_size_in_MB``).
    One ``c_reduce_sum`` per fused var is inserted right after the coalesce
    ops. When ``use_calc_stream`` is False, a ``c_sync_calc_stream`` is
    inserted at the same position, just ahead of the reduce.

    Returns:
        The grad names rooted at ``rank`` (in input order), or ``[]`` when
        ``rank`` is None.
    """
    nranks = shard.worker_num
    grads_by_root = [[] for _ in range(nranks)]

    for grad_name in reduce_vars:
        owner = get_grad_device(grad_name, shard)
        assert 0 <= owner < nranks, (
            "root_id should >=0 and < nranks, "
            "but now nranks={}, the root_id of var={} is {}".format(
                nranks, grad_name, owner))
        grads_by_root[owner].append(grad_name)

    for root_id, names in enumerate(grads_by_root):
        fused_vars, coalesce_count = FuseHelper.insert_coalesce_tensor(
            block,
            insert_idx,
            FuseHelper.get_fused_groups(block, names, fuse_grad_size),
            op_role,
            prefix="Grad")

        pos = insert_idx + coalesce_count
        for fused_var in fused_vars:
            block._insert_op_without_sync(pos,
                                          type='c_reduce_sum',
                                          inputs={'X': fused_var},
                                          outputs={'Out': fused_var},
                                          attrs={
                                              'ring_id': ring_id,
                                              'root_id': root_id,
                                              'use_calc_stream':
                                              use_calc_stream,
                                              OP_ROLE_KEY: op_role
                                          })
            if use_calc_stream:
                continue
            block._insert_op_without_sync(pos,
                                          type='c_sync_calc_stream',
                                          inputs={'X': fused_var},
                                          outputs={'Out': fused_var},
                                          attrs={OP_ROLE_KEY: op_role})

    if rank is None:
        return []
    return grads_by_root[rank]


def insert_reduce_ops(block,
                      insert_idx,
                      ring_id,
                      reduce_vars,
                      shard,
                      op_role=OpRole.Backward,
                      use_calc_stream=False,
                      rank=None,
                      strategy=None):
    """
    Insert ``c_reduce_sum`` ops sending each grad to its owning rank.

    When ``strategy`` enables ``fuse_all_reduce_ops`` without
    ``fuse_grad_merge``, this delegates to ``insert_fused_reduce_ops``.
    Otherwise one reduce per var is inserted at ``insert_idx``. The root
    rank comes from ``get_grad_device``. When both fuse options are on, the
    ``FusedMergedGrad_`` prefix is stripped from the name before that
    lookup.

    Returns:
        The vars whose root rank equals ``rank`` (empty when ``rank`` is
        None).
    """
    if strategy and strategy.fuse_all_reduce_ops and \
            not strategy.fuse_grad_merge:
        return insert_fused_reduce_ops(block, insert_idx, ring_id, reduce_vars,
                                       shard, op_role, use_calc_stream, rank,
                                       strategy.fuse_grad_size_in_MB)

    strip_merge_prefix = bool(strategy and strategy.fuse_all_reduce_ops and
                              strategy.fuse_grad_merge)

    grad_in_this_device = []
    for var in reduce_vars:
        # TODO(wangxi): if support fp16_allreduce, need be
        # 'FusedMergedGrad.cast_fp16._'
        lookup_name = (var.replace('FusedMergedGrad_', '')
                       if strip_merge_prefix else var)
        root_id = get_grad_device(lookup_name, shard)
        assert root_id >= 0, "root id should be a positive int, but now root id is {}".format(
            root_id)
        if rank is not None and root_id == rank:
            grad_in_this_device.append(var)
        reduce_attrs = {
            'ring_id': ring_id,
            'root_id': root_id,
            'use_calc_stream': use_calc_stream,
            OP_ROLE_KEY: op_role,
        }
        block._insert_op_without_sync(insert_idx,
                                      type='c_reduce_sum',
                                      inputs={'X': var},
                                      outputs={'Out': var},
                                      attrs=reduce_attrs)

    return grad_in_this_device


def insert_fused_broadcast_param_ops(block,
                                     insert_idx,
                                     ring_id,
                                     params,
                                     shard,
                                     op_role=OpRole.Optimize,
                                     use_calc_stream=False,
                                     rank=None,
                                     fuse_size=32):
    """
    Broadcast ``params`` from their owning ranks through coalesced buffers.

    Params are bucketed by ``shard.device(param)``. Each bucket is coalesced
    with ``FuseHelper`` using ``fuse_size`` as the group limit. One
    ``c_broadcast`` per fused var is inserted right after the coalesce ops.
    When ``use_calc_stream`` is False, a ``c_sync_calc_stream`` is inserted
    at the same position, just ahead of the broadcast.

    Returns:
        The params owned by ``rank`` (in input order), or ``[]`` when
        ``rank`` is None.
    """
    nranks = shard.worker_num
    params_by_root = [[] for _ in range(nranks)]

    for param in params:
        owner = shard.device(param)
        assert 0 <= owner < nranks, (
            "root_id should >=0 and < nranks, "
            "but now nranks={}, the root_id of var={} is {}".format(
                nranks, param, owner))
        params_by_root[owner].append(param)

    for root_id, names in enumerate(params_by_root):
        fused_vars, coalesce_count = FuseHelper.insert_coalesce_tensor(
            block,
            insert_idx,
            FuseHelper.get_fused_groups(block, names, fuse_size),
            op_role,
            prefix="Param")

        pos = insert_idx + coalesce_count
        for fused_var in fused_vars:
            block._insert_op_without_sync(pos,
                                          type='c_broadcast',
                                          inputs={'X': fused_var},
                                          outputs={'Out': fused_var},
                                          attrs={
                                              'ring_id': ring_id,
                                              'root': root_id,
                                              'use_calc_stream':
                                              use_calc_stream,
                                              OP_ROLE_KEY: op_role
                                          })
            if use_calc_stream:
                continue
            block._insert_op_without_sync(pos,
                                          type='c_sync_calc_stream',
                                          inputs={'X': fused_var},
                                          outputs={'Out': fused_var},
                                          attrs={OP_ROLE_KEY: op_role})

    if rank is None:
        return []
    return params_by_root[rank]


def insert_broadcast_param_ops(block,
                               insert_idx,
                               ring_id,
                               params,
                               shard,
                               op_role=OpRole.Optimize,
                               use_calc_stream=False,
                               rank=None,
                               strategy=None):
    """
    Insert ``c_broadcast`` ops sending each param from its owning rank.

    When ``strategy`` enables ``fuse_all_reduce_ops``, this delegates to
    ``insert_fused_broadcast_param_ops`` with ``fuse_grad_size_in_MB`` as
    the buffer limit. Otherwise one broadcast per param is inserted at
    ``insert_idx``, rooted at ``shard.device(param)``.

    Returns:
        The params whose root rank equals ``rank`` (empty when ``rank`` is
        None).
    """
    if strategy and strategy.fuse_all_reduce_ops:
        # TODO(wangxi): put fused var in startup_program, only need exec once
        return insert_fused_broadcast_param_ops(block, insert_idx, ring_id,
                                                params, shard, op_role,
                                                use_calc_stream, rank,
                                                strategy.fuse_grad_size_in_MB)

    param_in_this_device = []
    for param in params:
        owner = shard.device(param)
        assert owner >= 0, "root id should be a positive int, but now root id is {}".format(
            owner)
        if rank is not None and owner == rank:
            param_in_this_device.append(param)
        broadcast_attrs = {
            'ring_id': ring_id,
            'root': owner,
            'use_calc_stream': use_calc_stream,
            OP_ROLE_KEY: op_role,
        }
        block._insert_op_without_sync(insert_idx,
                                      type='c_broadcast',
                                      inputs={'X': param},
                                      outputs={'Out': param},
                                      attrs=broadcast_attrs)

    return param_in_this_device


def fuse_opt_broadcast_param_ops(block,
                                 ring_id,
                                 shard,
                                 op_role=OpRole.Optimize,
                                 strategy=None):
    """
    Fuse the trailing optimizer-role `c_broadcast` ops of `block`.

    Walking backwards from the end of `block`, every consecutive op that is
    both an optimizer op and a `c_broadcast` is removed, and its input var is
    recorded under the op's `root`. For each root, the vars are then sorted by
    dtype, grouped (FuseHelper, bounded by `strategy.fuse_grad_size_in_MB`),
    coalesced, and re-broadcast with one `c_broadcast` per fused var
    (`use_calc_stream=True`).

    Does nothing unless `strategy.fuse_all_reduce_ops` is set.

    Args:
        block: block whose trailing broadcasts are fused.
        ring_id (int): communication ring of the new broadcast ops.
        shard: sharding helper; `shard.worker_num` bounds the root ids.
        op_role: value stored under OP_ROLE_KEY on the inserted ops.
        strategy: strategy object; `fuse_all_reduce_ops` and
            `fuse_grad_size_in_MB` are read.
    """
    if strategy is None or not strategy.fuse_all_reduce_ops:
        return

    fuse_size = strategy.fuse_grad_size_in_MB

    nranks = shard.worker_num
    device_to_vars = [[] for _ in range(nranks)]

    # Index just after the last op that is kept. It defaults to 0 so that
    # an empty block, or a block whose ops are all removed below, gets the
    # new ops at the start (instead of an unbound name or an index one past).
    insert_idx = 0
    for idx, op in reversed(list(enumerate(block.ops))):
        if not is_optimizer_op(op) or op.type != 'c_broadcast':
            insert_idx = idx + 1
            break
        var = op.input_arg_names[0]
        root_id = op.attr('root')
        # prepend: we walk backwards, so this restores program order
        device_to_vars[root_id].insert(0, var)
        # only ops after `idx` are removed, so lower indices stay valid
        block._remove_op(idx, sync=False)

    for root_id, vars_name in enumerate(device_to_vars):
        vars_name = FuseHelper.sort_vars_by_dtype(block, vars_name)
        groups = FuseHelper.get_fused_groups(block, vars_name, fuse_size)

        fused_vars, insert_num = FuseHelper.insert_coalesce_tensor(
            block, insert_idx, groups, op_role, prefix="Param")

        for fused_var in fused_vars:
            block._insert_op_without_sync(insert_idx + insert_num,
                                          type='c_broadcast',
                                          inputs={'X': fused_var},
                                          outputs={'Out': fused_var},
                                          attrs={
                                              'ring_id': ring_id,
                                              'root': root_id,
                                              'use_calc_stream': True,
                                              OP_ROLE_KEY: op_role
                                          })

    block._sync_with_cpp()


def get_grad_device(grad_name, shard):
    """
    Return the device that owns the parameter of gradient var `grad_name`.

    The parameter name is recovered by removing the first gradient suffix
    from `possible_suffixes` that occurs in `grad_name`. It is then looked up
    in `shard.global_param2device`.

    Raises:
        AssertionError: if `grad_name` does not contain "@GRAD", or if the
            recovered name is not a key of `shard.global_param2device`.
    """
    assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format(
        grad_name)
    base_name = None
    # NOTE: mind the traversal order -- a suffix must be tried before any
    # shorter suffix it contains (e.g. '@GRAD@MERGED' before '@GRAD').
    possible_suffixes = [
        # sharding gm
        '.cast_fp16@GRAD@MERGED',
        '.cast_fp16@GRAD',
        # pipeline
        '@GRAD@MERGED@FP16',
        '@GRAD@MERGED',
        '@GRAD',
    ]
    for suffix in possible_suffixes:
        if suffix in grad_name:
            # Literal removal: the suffixes contain '.', which a regex
            # substitution would treat as "any character".
            base_name = grad_name.replace(suffix, '')
            break

    assert base_name in shard.global_param2device, "[{}] should be a param variable.".format(
        base_name)

    return shard.global_param2device[base_name]


def get_first_check_finite_and_unscale_op_idx(block, raise_error=True):
    """
    Return the index of the first `check_finite_and_unscale` op in `block`.

    Args:
        block: block whose `ops` are scanned front to back.
        raise_error (bool): whether to raise when no such op exists.

    Returns:
        int: index of the first matching op, or -1 when there is none and
        `raise_error` is False.

    Raises:
        ValueError: when there is no such op and `raise_error` is True.
    """
    first_idx = next((idx for idx, op in enumerate(block.ops)
                      if op.type == "check_finite_and_unscale"), -1)

    if first_idx == -1 and raise_error:
        raise ValueError(
            "amp is turned on but check_finite_and_unscale op does not exist in main block"
        )

    return first_idx


def get_first_optimize_op_idx(block):
    """
    Return the index just after the last backward op in `block`.

    Ops are scanned from the end. The first backward op found (the last one
    in program order) determines the result.

    Returns:
        int|None: that index + 1, or None when `block` has no backward op.
    """
    for index, op in reversed(tuple(enumerate(block.ops))):
        if is_backward_op(op):
            return index + 1
    return None


def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root):
    """
    Insert a `c_broadcast` op at `insert_idx` for every pair in
    `broadcast2root`.

    Every inserted op gets the role returned by
    `get_valid_op_role(block, insert_idx)`. All ops are inserted at the same
    index, so they appear in the block in the reverse order of
    `broadcast2root`.

    Args:
        block: block to insert the ops into.
        insert_idx (int): op index at which the broadcasts are inserted.
        ring_id (int): communication ring of the broadcast ops.
        broadcast2root: iterable of `(var_name, root_device)` pairs.
    """
    op_role = get_valid_op_role(block, insert_idx)
    for broadcast_name, root_device in broadcast2root:
        block._insert_op_without_sync(insert_idx,
                                      type='c_broadcast',
                                      inputs={'X': broadcast_name},
                                      outputs={'Out': broadcast_name},
                                      attrs={
                                          'ring_id': ring_id,
                                          'root': root_device,
                                          OP_ROLE_KEY: op_role
                                      })


# Bytes per element for each supported var dtype; get_var_size() multiplies
# a var's element count by the entry for its dtype.
DtypeToSize = {
    core.VarDesc.VarType.FP16: 2,  # 16-bit float
    core.VarDesc.VarType.FP32: 4,  # 32-bit float
    core.VarDesc.VarType.FP64: 8,  # 64-bit float
    core.VarDesc.VarType.INT16: 2,  # 16-bit int
    core.VarDesc.VarType.INT32: 4,  # 32-bit int
    core.VarDesc.VarType.INT64: 8,  # 64-bit int
    core.VarDesc.VarType.BOOL: 1,  # one byte per bool
    core.VarDesc.VarType.UINT8: 1,  # 8-bit unsigned int
}


def get_var_size(param):
    """
    Return the size of `param` in MB.

    input:
        - param: var with a fully known `shape` (no -1 dims) and a `dtype`
          that is a key of DtypeToSize
    return:
        var size in MB

    Raises:
        AssertionError: if `param.shape` contains -1.
        KeyError: if `param.dtype` is not in DtypeToSize.
    """
    assert -1 not in param.shape
    # initial value 1 so a 0-d var (empty shape) counts as one element
    # instead of making reduce() raise TypeError
    numel = reduce(lambda x, y: x * y, param.shape, 1)
    return numel * DtypeToSize[param.dtype] / 1024.0 / 1024.0


def insert_scale_loss_grad_ops(block, scale=1.0):
    '''
    In order to keep the learning rate consistent in different numbers of
    training workers, we scale the loss grad by the number of workers.

    The last loss-grad op of `block` must be a `fill_constant` op. Its
    `value` attr is divided by `scale` in place. If `block` has no loss-grad
    op, nothing is changed.

    Raises:
        AssertionError: if the loss-grad op is not `fill_constant` or has no
            `value` attr.
    '''
    for op in reversed(list(block.ops)):
        if not is_loss_grad_op(op):
            continue
        assert op.type == 'fill_constant', \
            "loss_grad_op must be fill_constant op, " \
            "but this op is {}".format(op.type)
        assert op.has_attr('value')
        loss_scale = float(op.attr('value'))
        loss_scale = loss_scale / scale
        op._set_attr('value', loss_scale)
        # only the last loss-grad op is rescaled
        break


def comm_analyse(main_program):
    """
    Analyse the size of the vars that need to be broadcast/allreduced during
    sharding training.

    For every `c_broadcast` and `c_allreduce_sum` op in the global block of
    `main_program`, the size in KB of the op's first input var is printed.
    A histogram of those sizes, with 1 KB wide buckets, is then printed and
    also written to "nccl_size.txt" in the current working directory.

    Args:
        main_program: program whose global block is analysed.
    """
    block = main_program.global_block()

    def _size_in_kb(var_name):
        # get_var_size returns MB; convert MB to KB
        return get_var_size(block.var(var_name)) * 1024.0

    reduce_vars = {}
    broadcast_vars = {}
    for op in block.ops:
        if op.type == "c_broadcast":
            var_name = op.desc.input_arg_names()[0]
            broadcast_vars[var_name] = _size_in_kb(var_name)
        elif op.type == "c_allreduce_sum":
            var_name = op.desc.input_arg_names()[0]
            reduce_vars[var_name] = _size_in_kb(var_name)

    varsize_count = {}
    gap = 1  # bucket width in KB; the "{}~{} KB" labels below assume 1

    # broadcast vars are reported first, then allreduce vars; both feed the
    # same histogram
    for label, sizes in (("broadcast", broadcast_vars), ("allreduce",
                                                         reduce_vars)):
        for k, v in sizes.items():
            print("{}: {}: {} KB".format(label, k, v))
            bucket = int(v / gap)
            varsize_count[bucket] = varsize_count.get(bucket, 0) + 1

    with open("nccl_size.txt", 'w') as f:
        for varsize, count in sorted(varsize_count.items()):
            print("NCCL size {}~{} KB: {}".format(varsize, varsize + 1, count))
            f.write("NCCL size {}~{} KB: {}\n".format(varsize, varsize + 1,
                                                      count))


def add_sync_comm(program, sharding_ring_id):
    """
    When a test program is cloned from the sharding main program, some
    c_sync_comm_stream ops may be pruned by mistake. This function appends
    one `c_sync_comm_stream` op (on `sharding_ring_id`) to the global block,
    covering every input of the tracked communication ops that is not an
    input of a later `c_sync_comm_stream` op.

    Raises:
        AssertionError: if `sharding_ring_id` is negative.
    """
    # NOTE (liangjianzhong): only support one comm stream by now, use more than one
    # comm streams will cause error. should be revise in future.

    assert sharding_ring_id >= 0, "sharding_ring_id should be non-negative"
    block = program.global_block()
    # dict used as an insertion-ordered set, so the appended op lists its
    # inputs in a deterministic (program) order
    not_sync_vars = {}
    for op in block.ops:
        # NOTE(review): "c_allreduce" does not match the "c_allreduce_sum" op
        # type used elsewhere in this file; confirm whether allreduce inputs
        # are meant to be tracked here.
        if op.type in ["c_broadcast", "c_allreduce"]:
            for input_name in op.desc.input_arg_names():
                not_sync_vars[input_name] = None
        if op.type == "c_sync_comm_stream":
            for input_name in op.desc.input_arg_names():
                # tolerate syncs of vars that were not collected above
                not_sync_vars.pop(input_name, None)
    if not_sync_vars:
        sync_vars = list(not_sync_vars)
        block.append_op(type='c_sync_comm_stream',
                        inputs={'X': sync_vars},
                        outputs={'Out': sync_vars},
                        attrs={
                            'ring_id':
                            sharding_ring_id,
                            'op_role':
                            core.op_proto_and_checker_maker.OpRole.Forward
                        })


def save_persistables(exe, dirname, main_program, filename=None):
    """
    Save the persistable vars of a sharding-trained program.

    With sharding, some persistable vars are unique and partitioned across
    ranks, while others are duplicated on every rank with different values.
    Rank 0 (PADDLE_TRAINER_ID == 0, also the default when the variable is
    unset) saves every persistable var of the program. Every other rank
    saves only the vars accepted by `sharding_predicate`: trainable
    parameters, optimizer/offload vars recognised by suffix, and
    gradient-merge vars.

    If `main_program` has `_pipeline_opt` set, its 'section_program' is
    saved instead.
    """
    # TODO (JZ-LIANG) revise this for uniform mixed parallelism
    if main_program._pipeline_opt:
        main_program = main_program._pipeline_opt['section_program']

    # NOTE(JZ-LIANG): The checks should be updated when add new compatible optimizer
    # now only Momentum and adam are compatible with sharding,
    # support EMA optimizer with '_ema_0',
    # support offload with '@offload_0' and '.cast_fp16'
    opt_var_suffixes = ("_moment1_0", "_moment2_0", "_beta1_pow_acc_0",
                        "_beta2_pow_acc_0", "_velocity_0", "_ema_0",
                        "@offload_0", ".cast_fp16")

    def is_opt_vars(var):
        return var.name.endswith(opt_var_suffixes) and var.persistable

    def is_gradient_merge_vars(var):
        # NOTE(JZ-LIANG): to revise save/load logic in framework instead of write this naive rule
        return var.name.endswith("@GradiantMerge")

    def is_trainable(var):
        return isinstance(var,
                          paddle.fluid.framework.Parameter) and var.trainable

    def sharding_predicate(var):
        return is_trainable(var) or is_opt_vars(var) or is_gradient_merge_vars(
            var)

    # NOTE(review): the `filename` argument is not forwarded (both calls pass
    # filename=None), so vars are always saved one file per var. Presumably
    # intentional, so that ranks writing into the same `dirname` do not
    # overwrite one combined file; confirm before forwarding it.
    if int(os.environ.get('PADDLE_TRAINER_ID', 0)) == 0:
        paddle.fluid.io.save_persistables(exe,
                                          dirname,
                                          main_program=main_program,
                                          filename=None)
    else:
        paddle.fluid.io.save_vars(exe,
                                  dirname,
                                  main_program=main_program,
                                  predicate=sharding_predicate,
                                  filename=None)


def append_naive_sync(block, sync_var, ring_id):
    """
    Append a naive synchronization over ring `ring_id` to the end of `block`.

    Three ops are appended:
    1. `sync_var` is filled with 1 (`fill_constant`).
    2. It is allreduce-summed over `ring_id` on the calc stream.
    3. A `c_sync_calc_stream` op is applied to it.

    Args:
        block: block to append the ops to.
        sync_var: var used as the synchronization token; its `shape` and
            `dtype` are reused by the fill op.
        ring_id (int): communication ring to synchronize over.
    """
    # NOTE (JZ-LIANG) update this to use barrier sync for more elegent logic
    # sync within global
    block.append_op(type="fill_constant",
                    outputs={"Out": sync_var},
                    attrs={
                        "shape": sync_var.shape,
                        "dtype": sync_var.dtype,
                        "value": int(1),
                    })
    block.append_op(type='c_allreduce_sum',
                    inputs={'X': sync_var},
                    outputs={'Out': sync_var},
                    attrs={
                        'ring_id': ring_id,
                        'use_calc_stream': True,
                        OP_ROLE_KEY: OpRole.Forward
                    })
    block.append_op(type='c_sync_calc_stream',
                    inputs={'X': [sync_var]},
                    outputs={'Out': [sync_var]},
                    attrs={OP_ROLE_KEY: OpRole.Forward})