utils.py 25.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
J
JZ-LIANG 已提交
14
import paddle
15
from paddle.fluid import core, unique_name
16 17 18 19 20
from functools import reduce
from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY

import re
J
JZ-LIANG 已提交
21
import os
22 23 24 25 26 27 28 29 30


def check_broadcast(block):
    """
    Validate the ordering of asynchronous broadcast ops in ``block``.

    Only ``c_broadcast`` ops with ``use_calc_stream`` False whose input var
    name contains "@BroadCast" are tracked. For each such var:

      * it must be broadcast at most once in the block;
      * if a ``fill_constant`` op writes it, that fill_constant must come
        before a ``c_sync_calc_stream`` op, which in turn must come before
        the broadcast;
      * every other op reading it must come after a ``c_sync_comm_stream``
        op that itself comes after the broadcast.

    Broadcasts that run on the calc stream (or whose var lacks "@BroadCast")
    are not tracked. NOTE(review): per the original notes this is meant to
    skip inner-parallelism (e.g. Megatron) broadcasts; the filter is by
    attribute and name only, not by ring id.

    Raises:
        ValueError: if the same var is broadcast twice.
        AssertionError: if one of the ordering rules above is violated.
    """
    broadcast_vars = {}
    for idx, op in enumerate(block.ops):
        if op.type != "c_broadcast":
            continue
        if op.all_attrs()["use_calc_stream"]:
            continue
        var_name = op.desc.input_arg_names()[0]
        if "@BroadCast" not in var_name:
            continue
        if var_name in broadcast_vars:
            raise ValueError("var_name already exists: {}, "
                             "the old pos is {}, the new pos is {}".format(
                                 var_name,
                                 broadcast_vars[var_name]["broadcast_pos"],
                                 idx))
        broadcast_vars[var_name] = {
            "fill_constant_pos": -1,
            "broadcast_pos": idx,
        }

    # Record where each tracked var is written by fill_constant (if several
    # fill_constant ops write it, the last one wins).
    for idx, op in enumerate(block.ops):
        if op.type == "fill_constant":
            var_name = op.desc.output_arg_names()[0]
            if var_name in broadcast_vars:
                broadcast_vars[var_name]["fill_constant_pos"] = idx

    last_sync_comm_op_idx = -1
    last_sync_calc_op_idx = -1
    for idx, op in enumerate(block.ops):
        if op.type == "c_sync_comm_stream":
            last_sync_comm_op_idx = idx
            continue
        if op.type == "c_sync_calc_stream":
            last_sync_calc_op_idx = idx
            continue
        if op.type == "c_broadcast" and not op.all_attrs()["use_calc_stream"]:
            var_name = op.desc.input_arg_names()[0]
            if "@BroadCast" in var_name:
                fill_constant_pos = broadcast_vars[var_name][
                    "fill_constant_pos"]
                if fill_constant_pos != -1:
                    # required order: fill_constant -> sync_calc -> broadcast
                    assert last_sync_calc_op_idx != -1
                    assert fill_constant_pos < last_sync_calc_op_idx
                    assert last_sync_calc_op_idx < idx
                continue
        # Any other op (including untracked broadcasts) reading a tracked var:
        # required order is broadcast -> sync_comm -> reader.
        for input_name in op.desc.input_arg_names():
            if input_name in broadcast_vars:
                broadcast_pos = broadcast_vars[input_name]["broadcast_pos"]
                assert broadcast_pos != -1
                assert broadcast_pos < last_sync_comm_op_idx
                assert last_sync_comm_op_idx < idx
    return


def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1):
    """
    Validate the ordering of gradient-synchronisation ops in ``block``.

    For every var reduced by an asynchronous (``use_calc_stream`` False)
    ``c_allreduce_sum`` / ``c_reduce_sum`` op, the op order should be:

        grad:
            - 0: op that generates the Var
            - 1: c_sync_calc_stream
            - 2: c_reduce_sum in the sharding ring (allreduce --> reduce)
            - 3: c_sync_comm_stream in the sharding ring
            - 4: c_allreduce_sum in the dp ring (dp_grads only)
            - 5: c_sync_comm_stream in the dp ring (dp_grads only)
            - 6: op that uses the Var (dp_grads only)

    Tracked vars fall into two groups:

        - ``vars_status``: vars whose name contains "sum", and grads whose
          param is not held by ``shard``. Any ordinary op reading such a var
          raises ValueError.
        - ``dp_grads_status``: grads whose param is held by ``shard``. They
          must be in state 3 (``dp_ring_id == -1``) or state 5 (otherwise)
          before an ordinary op may read them.

    Collectives with ``use_calc_stream`` True are not checked; per the
    original notes these are presumably inner-parallelism (e.g. Megatron)
    allreduces.

    The last "sum" (amp) allreduce and the last ``c_allreduce_max``
    (gradient clip) must both come after the last grad allreduce/reduce.

    Raises:
        ValueError: on an out-of-order or illegal use of a reduced var.
        AssertionError: on other inconsistencies (wrong ring, wrong state).
    """
    vars_status = {}
    dp_grads_status = {}
    idx_last_grad_allreduce = -1
    idx_amp_allreduce = -1
    idx_gradient_clip_allreduce = -1

    # Pass 1: register every asynchronously reduced var in state -1
    # ("not generated yet") and remember where the amp / grad / clip
    # collectives sit.
    for idx, op in enumerate(block.ops):
        # sharding uses both allreduce and reduce to sync grads
        if op.type in ("c_allreduce_sum", "c_reduce_sum"):
            if not op.all_attrs()["use_calc_stream"]:
                ring_id = op.desc.attr("ring_id")
                var_name = op.desc.input_arg_names()[0]
                param = var_name.split("@")[0]

                assert 'sum' in var_name or ("@GRAD" in var_name)
                if 'sum' in var_name or (not shard.has_param(param)):
                    vars_status[var_name] = -1
                else:
                    dp_grads_status[var_name] = -1

                if ring_id != sharding_ring_id:
                    assert shard.has_param(param)
                    assert ring_id == dp_ring_id

                if "sum" in var_name:
                    idx_amp_allreduce = idx
                elif "@GRAD" in var_name:
                    idx_last_grad_allreduce = idx

        if op.type == "c_allreduce_max":
            idx_gradient_clip_allreduce = idx

    # Pass 2: walk the ops in order and advance each var's state machine.
    for op in block.ops:
        if op.type == "c_sync_calc_stream":
            # A calc-stream sync covers every var generated so far: 0 -> 1.
            for var_name in vars_status:
                if vars_status[var_name] == 0:
                    vars_status[var_name] = 1
            for var_name in dp_grads_status:
                if dp_grads_status[var_name] == 0:
                    dp_grads_status[var_name] = 1

        # check sharding allreduce and reduce but skip megatron allreduce
        elif op.type in ("c_allreduce_sum", "c_reduce_sum"):
            if op.all_attrs()["use_calc_stream"]:
                continue
            var_name = op.desc.input_arg_names()[0]
            ring_id = op.desc.attr("ring_id")
            if ring_id == sharding_ring_id:
                assert op.type == "c_reduce_sum", "Grad in Sharding group should be reduce rather than allreduce"
                if var_name in vars_status:
                    _status = vars_status[var_name]
                else:
                    _status = dp_grads_status[var_name]
                if _status == -1:
                    raise ValueError("{} is not generated, but you are "
                                     "trying to all-reduce it".format(
                                         var_name))
                if _status == 0:
                    raise ValueError("There should be a sync_calc op "
                                     "after generate Var: {} and before the "
                                     "c_allreduce_sum op".format(var_name))
                assert (_status == 1)
                # 1 -> 2: reduced inside the sharding ring
                if var_name in vars_status:
                    vars_status[var_name] = 2
                else:
                    dp_grads_status[var_name] = 2
            else:
                # 3 -> 4: allreduced across the dp ring
                assert ring_id == dp_ring_id
                param = var_name.split("@")[0]
                assert shard.has_param(param)
                assert dp_grads_status[var_name] == 3
                dp_grads_status[var_name] = 4

        elif op.type == "c_sync_comm_stream":
            ring_id = op.desc.attr("ring_id")
            if ring_id == sharding_ring_id:
                # 2 -> 3: sharding-ring comm stream synced
                for var_name in op.desc.input_arg_names():
                    if var_name in vars_status:
                        assert vars_status[var_name] == 2
                        vars_status[var_name] = 3
                    elif var_name in dp_grads_status:
                        assert dp_grads_status[var_name] == 2
                        dp_grads_status[var_name] = 3
            else:
                # 4 -> 5: dp-ring comm stream synced
                for var_name in op.desc.input_arg_names():
                    param = var_name.split("@")[0]
                    assert ring_id == dp_ring_id
                    assert shard.has_param(param)
                    assert dp_grads_status[var_name] == 4
                    dp_grads_status[var_name] = 5

        else:
            for input_name in op.desc.input_arg_names():
                if input_name in vars_status:
                    if vars_status[input_name] != 3:
                        raise ValueError("There should be a sync_comm op "
                                         "after allreduce the Var: {}".format(
                                             input_name))
                    # Even when fully synced, a var in vars_status must not
                    # be consumed on this rank at all.
                    raise ValueError(
                        "The reduce output grad [{}] should NOT be used in Non-root rank.".
                        format(input_name))
                if input_name in dp_grads_status:
                    if dp_ring_id == -1:
                        if dp_grads_status[input_name] != 3:
                            raise ValueError("There should be a sync_comm op "
                                             "after allreduce the Var: {}".
                                             format(input_name))
                    else:
                        if dp_grads_status[input_name] != 5:
                            raise ValueError(
                                "The grad in shard should be allreduce and sync "
                                "twice before usage {}".format(input_name))

            # -1 -> 0: the var is generated by this op
            for output_name in op.desc.output_arg_names():
                if output_name in vars_status and \
                        vars_status[output_name] == -1:
                    vars_status[output_name] = 0
                if output_name in dp_grads_status and \
                        dp_grads_status[output_name] == -1:
                    dp_grads_status[output_name] = 0

    # check sharding with amp
    if idx_amp_allreduce != -1:
        assert idx_amp_allreduce > idx_last_grad_allreduce

    # check sharding with gradient_clip_by_global_norm
    if idx_gradient_clip_allreduce != -1:
        assert idx_gradient_clip_allreduce > idx_last_grad_allreduce

    return


def get_valid_op_role(block, insert_idx):
    """
    Return the op role to give an op inserted at ``insert_idx`` of ``block``:
    OpRole.Forward or OpRole.Backward.

    Ops are scanned forward starting at ``insert_idx``:

      * the first Backward or Optimize op yields OpRole.Backward;
      * the first Forward or Loss op yields OpRole.Forward;
      * ops with any other role are skipped.

    Running past the end of the block yields OpRole.Backward. This includes
    ``insert_idx == len(block.ops)``, i.e. appending.
    """
    backward_roles = [int(OpRole.Backward), int(OpRole.Optimize)]
    forward_roles = [int(OpRole.Forward), int(OpRole.Loss)]
    # The loop bound is checked before indexing, so inserting at the end of
    # the block does not raise IndexError. Iterating instead of recursing
    # also avoids the recursion limit on long runs of "other"-role ops.
    for idx in range(insert_idx, len(block.ops)):
        op_role = block.ops[idx].attr('op_role')
        if op_role in backward_roles:
            return OpRole.Backward
        if op_role in forward_roles:
            return OpRole.Forward
    return OpRole.Backward


def insert_sync_calc_op(block, insert_idx, calc_dep_vars):
    """
    Insert one ``c_sync_calc_stream`` op at position ``insert_idx``.

    ``calc_dep_vars`` is wired as both the X input and the Out output. The
    op role comes from get_valid_op_role for the same position. Returns None.
    """
    sync_attrs = {OP_ROLE_KEY: get_valid_op_role(block, insert_idx)}
    block._insert_op_without_sync(
        insert_idx,
        type='c_sync_calc_stream',
        inputs={'X': calc_dep_vars},
        outputs={'Out': calc_dep_vars},
        attrs=sync_attrs)


def insert_sync_comm_op(block, insert_idx, ring_id, comm_dep_vars):
    """
    Insert one ``c_sync_comm_stream`` op on ``ring_id`` at ``insert_idx``.

    ``comm_dep_vars`` is wired as both the X input and the Out output. The
    op role comes from get_valid_op_role. Always returns 1 (one op inserted).
    """
    role = get_valid_op_role(block, insert_idx)
    sync_attrs = {'ring_id': ring_id, OP_ROLE_KEY: role}
    block._insert_op_without_sync(
        insert_idx,
        type='c_sync_comm_stream',
        inputs={'X': comm_dep_vars},
        outputs={'Out': comm_dep_vars},
        attrs=sync_attrs)
    inserted_count = 1
    return inserted_count


def insert_sync_comm_ops(block, insert_idx, ring_id, comm_dep_vars):
    """
    Insert a single ``c_sync_comm_stream`` op covering all ``comm_dep_vars``.

    ``ring_id`` is coerced to ``int`` before use. Returns the number of ops
    inserted: 0 when ``comm_dep_vars`` is empty (the block is left
    untouched), otherwise 1.
    """
    # NOTE (JZ-LIANG) to be checked, may result in an undefined case
    if len(comm_dep_vars) > 0:
        role = get_valid_op_role(block, insert_idx)
        block._insert_op_without_sync(
            insert_idx,
            type='c_sync_comm_stream',
            inputs={'X': comm_dep_vars},
            outputs={'Out': comm_dep_vars},
            attrs={
                'ring_id': int(ring_id),
                OP_ROLE_KEY: role,
            })
        return 1
    return 0


def insert_fill_constant_ops(block, insert_idx, fill_constant_vars):
    """
    Insert one ``fill_constant`` op per var name in ``fill_constant_vars``.

    Each op sets its var to 0.0, using the var's own shape and dtype. All ops
    are inserted at ``insert_idx`` with the role chosen by
    get_valid_op_role. Returns None.
    """
    role = get_valid_op_role(block, insert_idx)
    for name in fill_constant_vars:
        target = block.var(name)
        zero_attrs = {
            "shape": target.shape,
            "dtype": target.dtype,
            "value": 0.0,
            OP_ROLE_KEY: role,
        }
        block._insert_op_without_sync(
            insert_idx,
            type="fill_constant",
            outputs={"Out": target.name},
            attrs=zero_attrs)


def insert_cast_ops(block, insert_idx, cast_ops):
    """
    Insert FP32 -> FP16 ``cast`` ops at ``insert_idx``.

    ``cast_ops`` maps each FP16 output var name to the FP32 var name it is
    cast from. One cast op is inserted per entry, with the role chosen by
    get_valid_op_role. Returns None.
    """
    role = get_valid_op_role(block, insert_idx)
    for dst_name, src_name in cast_ops.items():
        cast_attrs = {
            "in_dtype": core.VarDesc.VarType.FP32,
            "out_dtype": core.VarDesc.VarType.FP16,
            OP_ROLE_KEY: role,
        }
        block._insert_op_without_sync(
            insert_idx,
            type="cast",
            inputs={"X": src_name},
            outputs={"Out": dst_name},
            attrs=cast_attrs)


def insert_allreduce_ops(block,
                         insert_idx,
                         ring_id,
                         allreduce_vars,
                         op_role=OpRole.Backward,
                         use_calc_stream=False,
                         user_defined_strategy=None):
    """
    Insert ``c_allreduce_sum`` ops for ``allreduce_vars`` at ``insert_idx``.

    Nothing is inserted when ``allreduce_vars`` is empty. If a
    ``user_defined_strategy`` is given and its ``fuse_all_reduce_ops`` flag
    is set, the work is delegated to insert_fused_allreduce_ops with the
    strategy's ``fuse_grad_size_in_MB``. Otherwise one in-place
    c_allreduce_sum op is inserted per var. Returns None.
    """
    if len(allreduce_vars) == 0:
        return

    use_fusion = (user_defined_strategy and
                  user_defined_strategy.fuse_all_reduce_ops)
    if not use_fusion:
        for name in allreduce_vars:
            block._insert_op_without_sync(
                insert_idx,
                type='c_allreduce_sum',
                inputs={'X': name},
                outputs={'Out': name},
                attrs={
                    'ring_id': ring_id,
                    'use_calc_stream': use_calc_stream,
                    OP_ROLE_KEY: op_role
                })
        return

    insert_fused_allreduce_ops(block, insert_idx, ring_id, allreduce_vars,
                               op_role, use_calc_stream,
                               user_defined_strategy.fuse_grad_size_in_MB)


def insert_fused_allreduce_ops(block,
                               insert_idx,
                               ring_id,
                               allreduce_vars,
                               op_role=OpRole.Backward,
                               use_calc_stream=False,
                               fuse_grad_size_in_MB=32):
    """
    Allreduce ``allreduce_vars`` in fused buckets instead of one by one.

    The vars are walked in order and packed greedily into buckets. A new
    bucket starts when there is no bucket yet, when the var's dtype differs
    from the current bucket's, or when adding it would push the bucket's
    size (per get_var_size, in MB) past ``fuse_grad_size_in_MB``.

    For every bucket:

      * a fresh ``FusedOutput_<first var name>`` var is created;
      * a ``coalesce_tensor`` op (copy_data, use_align) fusing the bucket
        into that var is inserted at ``insert_idx``;
      * a ``c_allreduce_sum`` op on the fused var is inserted at
        ``insert_idx + <number of buckets>``;
      * when ``use_calc_stream`` is False, a ``c_sync_calc_stream`` op on the
        fused var is inserted at that same position.

    Returns None.
    """
    buckets = []
    bucket_size = 0.
    bucket_dtype = None
    for name in allreduce_vars:
        var = block.var(name)
        size = get_var_size(var)
        starts_new_bucket = (not buckets or var.dtype != bucket_dtype or
                             bucket_size + size > fuse_grad_size_in_MB)
        if starts_new_bucket:
            buckets.append([var])
            bucket_size = size
            bucket_dtype = var.dtype
        else:
            buckets[-1].append(var)
            bucket_size += size

    fused_outputs = []
    for bucket in buckets:
        head = bucket[0]
        fused = block.create_var(
            name=unique_name.generate('FusedOutput_{}'.format(head.name)),
            dtype=head.dtype,
            persistable=False,
            stop_gradient=True)
        fused_outputs.append(fused)
        block._insert_op_without_sync(
            insert_idx,
            type="coalesce_tensor",
            inputs={"Input": bucket},
            outputs={"Output": bucket,
                     "FusedOutput": fused},
            attrs={
                "copy_data": True,
                "use_align": True,
                "dtype": head.dtype,
                OP_ROLE_KEY: op_role
            })

    comm_idx = insert_idx + len(fused_outputs)
    for fused in fused_outputs:
        block._insert_op_without_sync(
            comm_idx,
            type='c_allreduce_sum',
            inputs={'X': fused},
            outputs={'Out': fused},
            attrs={
                'ring_id': ring_id,
                'use_calc_stream': use_calc_stream,
                OP_ROLE_KEY: op_role
            })
        if use_calc_stream:
            continue
        block._insert_op_without_sync(
            comm_idx,
            type='c_sync_calc_stream',
            inputs={'X': fused},
            outputs={'Out': fused},
            attrs={OP_ROLE_KEY: op_role})


def insert_reduce_ops(block,
                      insert_idx,
                      ring_id,
                      reduce_vars,
                      shard,
                      op_role=OpRole.Backward,
                      use_calc_stream=False):
    """
    Insert one in-place ``c_reduce_sum`` op per grad var in ``reduce_vars``.

    The ``root_id`` of each op is the value get_grad_device returns for that
    grad, i.e. the ``shard.global_param2device`` entry for its param. All
    ops are inserted at ``insert_idx``. Returns None.
    """
    for grad_name in reduce_vars:
        root = get_grad_device(grad_name, shard)
        assert root >= 0, "root id should be a positive int, but now root id is {}".format(
            root)
        reduce_attrs = {
            'ring_id': ring_id,
            'root_id': root,
            'use_calc_stream': use_calc_stream,
            OP_ROLE_KEY: op_role
        }
        block._insert_op_without_sync(
            insert_idx,
            type='c_reduce_sum',
            inputs={'X': grad_name},
            outputs={'Out': grad_name},
            attrs=reduce_attrs)


def get_grad_device(grad_name, shard):
    """
    Return the device recorded for the parameter behind ``grad_name``.

    The parameter name is recovered by removing the first suffix from
    ``possible_suffixes`` that occurs in ``grad_name``. The suffixes are
    checked in order, most specific first, so ``.cast_fp16@GRAD@MERGED``
    is not mistaken for plain ``@GRAD``. The recovered name is then looked
    up in ``shard.global_param2device``.

    NOTE(review): this module defines get_grad_device twice; the later
    definition is the one in effect at import time.

    Raises:
        AssertionError: if ``grad_name`` does not contain "@GRAD", or the
            recovered name is not a key of ``shard.global_param2device``.
    """
    assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format(
        grad_name)
    base_name = None
    # mind the traversal order: more specific suffixes must come first
    possible_suffixes = [
        '.cast_fp16@GRAD@MERGED', '.cast_fp16@GRAD', '@GRAD@MERGED', '@GRAD'
    ]
    for suffix in possible_suffixes:
        if suffix in grad_name:
            # Literal removal: the suffixes contain '.', which a regex
            # substitution would treat as a wildcard.
            base_name = grad_name.replace(suffix, '')
            break

    assert base_name in shard.global_param2device, "[{}] should be a param variable.".format(
        base_name)

    return shard.global_param2device[base_name]


def get_first_check_finite_and_unscale_op_idx(block, raise_error=True):
    """
    Return the index of the first ``check_finite_and_unscale`` op in
    ``block``.

    When no such op exists, raise ValueError if ``raise_error`` is true (the
    default), otherwise return -1.
    """
    first_idx = next((i for i, op in enumerate(block.ops)
                      if op.type == "check_finite_and_unscale"), None)
    if first_idx is not None:
        return first_idx

    if raise_error:
        raise ValueError(
            "amp is turned on but check_finite_and_unscale op does not exist in main block"
        )
    return -1


def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root):
    """
    Insert a ``c_broadcast`` op at ``insert_idx`` for every
    ``(var_name, root_device)`` pair in ``broadcast2root``.

    Each broadcast is in place (X and Out are the same var), runs on
    ``ring_id`` with its ``root`` attr set to ``root_device``, and gets the
    role chosen by get_valid_op_role. Returns None.
    """
    role = get_valid_op_role(block, insert_idx)
    for var_name, root in broadcast2root:
        broadcast_attrs = {
            'ring_id': ring_id,
            'root': root,
            OP_ROLE_KEY: role,
        }
        block._insert_op_without_sync(
            insert_idx,
            type='c_broadcast',
            inputs={'X': var_name},
            outputs={'Out': var_name},
            attrs=broadcast_attrs)


# Bytes per element for each VarType that get_var_size can size. A dtype
# missing from this table makes get_var_size raise KeyError.
DtypeToSize = {
    core.VarDesc.VarType.FP16: 2,
    core.VarDesc.VarType.FP32: 4,
    core.VarDesc.VarType.FP64: 8,
    core.VarDesc.VarType.INT16: 2,
    core.VarDesc.VarType.INT32: 4,
    core.VarDesc.VarType.INT64: 8,
    core.VarDesc.VarType.BOOL: 1,
    core.VarDesc.VarType.UINT8: 1,
}


def get_var_size(param):
    """
    Return the memory size of ``param`` in MB (units of 1024 * 1024 bytes).

    input:
        - param: var whose ``shape`` has no -1 (unknown) dimension and whose
          ``dtype`` is a key of DtypeToSize
    return:
        var size in MB, as a float. A 0-d var (empty shape) counts as a
        single element.
    """
    assert -1 not in param.shape
    # The initial value 1 makes an empty shape () give one element instead of
    # reduce() raising TypeError on an empty sequence.
    numel = reduce(lambda x, y: x * y, param.shape, 1)
    return numel * DtypeToSize[param.dtype] / 1024.0 / 1024.0


def insert_scale_loss_grad_ops(block, scale=1.0):
    '''
    Scale the loss gradient so the learning rate stays consistent across
    different numbers of training workers.

    Finds the last op in ``block`` for which is_loss_grad_op is true. Right
    after it, inserts an in-place ``scale`` op (factor ``scale``, Backward
    role) on that op's first output var. Does nothing if no such op exists.
    '''
    loss_grad_hits = ((idx, op)
                      for idx, op in reversed(list(enumerate(block.ops)))
                      if is_loss_grad_op(op))
    hit = next(loss_grad_hits, None)
    if hit is None:
        return

    idx, loss_grad_op = hit
    grad_var = block.vars[loss_grad_op.output_arg_names[0]]
    block._insert_op_without_sync(
        idx + 1,
        type='scale',
        inputs={'X': grad_var},
        outputs={'Out': grad_var},
        attrs={'scale': scale,
               OP_ROLE_KEY: OpRole.Backward})


def comm_analyse(main_program):
    """
    Analyse the parameter sizes that are broadcast / allreduced during
    sharding training.

    For each c_broadcast and c_allreduce_sum op in the global block, the
    size of its first input var is printed in KB. A histogram of those sizes
    in 1 KB buckets is then printed and written to "nccl_size.txt" in the
    current working directory.
    """
    block = main_program.global_block()
    broadcast_vars = {}
    reduce_vars = {}
    for op in block.ops:
        if op.type not in ("c_broadcast", "c_allreduce_sum"):
            continue
        name = op.desc.input_arg_names()[0]
        # get_var_size reports MB; convert MB to KB
        size_kb = get_var_size(block.var(name)) * 1024.0
        if op.type == "c_broadcast":
            broadcast_vars[name] = size_kb
        else:
            reduce_vars[name] = size_kb

    gap = 1
    varsize_count = {}

    def _tally(label, sizes):
        # Print each var's size and count it in its gap-wide bucket.
        for name, size in sizes.items():
            print("{}: {}: {} KB".format(label, name, size))
            bucket = int(size / gap)
            varsize_count[bucket] = varsize_count.get(bucket, 0) + 1

    _tally("broadcast", broadcast_vars)
    _tally("allreduce", reduce_vars)

    with open("nccl_size.txt", 'w') as f:
        for varsize, count in sorted(
                varsize_count.items(), key=lambda item: item[0]):
            line = "NCCL size {}~{} KB: {}".format(varsize, varsize + 1,
                                                   count)
            print(line)
            f.write(line + "\n")


def add_sync_comm(program, sharding_ring_id):
    """
    Re-add ``c_sync_comm_stream`` ops that were lost when a test program was
    cloned from the sharding main program.

    Cloning may prune some of these ops by mistake. The function collects the
    inputs of c_broadcast ops in the global block that are not followed by a
    c_sync_comm_stream op covering them. If any remain, it appends one
    c_sync_comm_stream op on ``sharding_ring_id`` (Forward role) over those
    vars, in sorted order, to the end of the block.
    """
    # NOTE (liangjianzhong): only one comm stream is supported for now; using
    # more than one comm stream will cause errors. To be revised in future.
    assert sharding_ring_id >= 0, "sharding_ring_id should larger than zero"

    block = program.global_block()
    not_sync_vars = set()
    for op in block.ops:
        # NOTE(review): "c_allreduce" is not a concrete op type (this file
        # uses c_allreduce_sum / c_reduce_sum), so in practice only
        # c_broadcast inputs are collected here. Confirm whether allreduce
        # vars should be covered as well.
        if op.type in ["c_broadcast", "c_allreduce"]:
            for input_name in op.desc.input_arg_names():
                not_sync_vars.add(input_name)
        if op.type == "c_sync_comm_stream":
            for input_name in op.desc.input_arg_names():
                # discard, not remove: a sync op may also cover vars that were
                # never collected above, which must not raise KeyError.
                not_sync_vars.discard(input_name)
    if not_sync_vars:
        # Sort so the op's input order does not depend on set iteration
        # order, which varies with string hash randomisation.
        sync_vars = sorted(not_sync_vars)
        block.append_op(
            type='c_sync_comm_stream',
            inputs={'X': sync_vars},
            outputs={'Out': list(sync_vars)},
            attrs={
                'ring_id': sharding_ring_id,
                'op_role': core.op_proto_and_checker_maker.OpRole.Forward
            })
    return


def save_persistables(exe, dirname, main_program, filename=None):
    """
    Save the persistable vars of a sharding program into ``dirname``.

    With sharding, some persistable vars are unique and partitioned across
    ranks. Others are duplicated in every rank, possibly with different
    values.

    The trainer whose PADDLE_TRAINER_ID is 0 (the default when the variable
    is unset) saves every persistable var. Every other trainer saves only:

      * trainable Parameters;
      * optimizer accumulators (names ending in the suffixes listed below);
      * vars whose name ends in "@GradiantMerge".

    For a pipeline program, its ``section_program`` is saved instead.

    NOTE(review): ``filename`` is accepted but ignored, since both save
    calls pass ``filename=None``. Presumably each var therefore goes to its
    own file; confirm against paddle.fluid.io.
    """
    # TODO (JZ-LIANG) revise this for uniform mixed parallelism
    if main_program._pipeline_opt:
        main_program = main_program._pipeline_opt['section_program']

    # NOTE(JZ-LIANG): update these suffixes when a new sharding-compatible
    # optimizer is added; for now only Momentum and Adam are compatible.
    opt_var_suffixes = ("_moment1_0", "_moment2_0", "_beta1_pow_acc_0",
                        "_beta2_pow_acc_0", "_velocity_0")

    def is_opt_vars(var):
        return var.name.endswith(opt_var_suffixes)

    def is_gradient_merge_vars(var):
        # NOTE(JZ-LIANG): to revise save/load logic in framework instead of
        # writing this naive rule
        return var.name.endswith("@GradiantMerge")

    def is_trainable(var):
        return isinstance(var, paddle.fluid.framework.Parameter) and \
            var.trainable

    def sharding_predicate(var):
        return (is_trainable(var) or is_opt_vars(var) or
                is_gradient_merge_vars(var))

    trainer_id = int(os.environ.get('PADDLE_TRAINER_ID', 0))
    if trainer_id != 0:
        paddle.fluid.io.save_vars(
            exe,
            dirname,
            main_program=main_program,
            predicate=sharding_predicate,
            filename=None)
    else:
        paddle.fluid.io.save_persistables(
            exe, dirname, main_program=main_program, filename=None)


def get_grad_device(grad_name, shard):
    """
    Return the device recorded for the parameter behind ``grad_name``.

    The parameter name is recovered by removing the first suffix from
    ``possible_suffixes`` that occurs in ``grad_name``. The suffixes are
    checked in order, most specific first, so ``.cast_fp16@GRAD@MERGED``
    and ``@GRAD@MERGED`` are not mistaken for plain ``@GRAD``. The recovered
    name is then looked up in ``shard.global_param2device``.

    NOTE(review): this definition shadows an earlier get_grad_device in the
    same module, so it must handle every suffix the earlier one does.

    Raises:
        AssertionError: if ``grad_name`` does not contain "@GRAD", or the
            recovered name is not a key of ``shard.global_param2device``.
    """
    assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format(
        grad_name)
    base_name = None
    # mind the traversal order: more specific suffixes must come first
    possible_suffixes = [
        '.cast_fp16@GRAD@MERGED', '.cast_fp16@GRAD', '@GRAD@MERGED', '@GRAD'
    ]
    for suffix in possible_suffixes:
        if suffix in grad_name:
            # Literal removal: the suffixes contain '.', which a regex
            # substitution would treat as a wildcard.
            base_name = grad_name.replace(suffix, '')
            break

    assert base_name in shard.global_param2device, "[{}] should be a param variable.".format(
        base_name)

    return shard.global_param2device[base_name]


def append_naive_sync(block, sync_var, ring_id):
    """
    Append ops to the end of ``block`` that synchronise via ``ring_id`` by
    allreducing a constant.

    Three ops are appended, in order:

      * a ``fill_constant`` setting ``sync_var`` to 1;
      * an in-place calc-stream ``c_allreduce_sum`` of ``sync_var`` on
        ``ring_id``;
      * a ``c_sync_calc_stream`` on ``sync_var``.

    The last two ops get the Forward role.
    """
    # NOTE (JZ-LIANG) update this to use barrier sync for more elegant logic
    # sync within global
    fill_attrs = {
        "shape": sync_var.shape,
        "dtype": sync_var.dtype,
        "value": 1,
    }
    block.append_op(
        type="fill_constant", outputs={"Out": sync_var}, attrs=fill_attrs)

    allreduce_attrs = {
        'ring_id': ring_id,
        'use_calc_stream': True,
        OP_ROLE_KEY: OpRole.Forward
    }
    block.append_op(
        type='c_allreduce_sum',
        inputs={'X': sync_var},
        outputs={'Out': sync_var},
        attrs=allreduce_attrs)

    block.append_op(
        type='c_sync_calc_stream',
        inputs={'X': [sync_var]},
        outputs={'Out': [sync_var]},
        attrs={OP_ROLE_KEY: OpRole.Forward})