utils.py 17.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
J
JZ-LIANG 已提交
14
import paddle
15 16 17 18 19 20
from paddle.fluid import core
from functools import reduce
from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY

import re
J
JZ-LIANG 已提交
21
import os
22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80


def check_broadcast(block):
    """
    Validate the ordering of broadcast-related ops in ``block``.

    For every ``c_broadcast`` op whose input var name contains "@BroadCast":

    * the var may be broadcast only once in the block;
    * if a ``fill_constant`` op writes that var, the fill must come before
      a ``c_sync_calc_stream`` op, which in turn must come before the
      broadcast;
    * every other op that reads the var must come after a
      ``c_sync_comm_stream`` op that itself comes after the broadcast.

    Args:
        block: program block whose ``ops`` are inspected; each op is read
            through ``op.type`` and ``op.desc.input_arg_names()`` /
            ``op.desc.output_arg_names()``.

    Raises:
        ValueError: if the same "@BroadCast" var is broadcast more than once.
        AssertionError: if any ordering constraint above is violated.
    """
    # var_name -> {"fill_constant_pos": idx or -1, "broadcast_pos": idx}
    broadcast_vars = {}
    for idx, op in enumerate(block.ops):
        if op.type != "c_broadcast":
            continue
        var_name = op.desc.input_arg_names()[0]
        if "@BroadCast" not in var_name:
            continue
        if var_name in broadcast_vars:
            raise ValueError("var_name already exists: {}, "
                             "the old pos is {}, the new pos is {}".format(
                                 var_name,
                                 broadcast_vars[var_name]["broadcast_pos"],
                                 idx))
        broadcast_vars[var_name] = {
            "fill_constant_pos": -1,
            "broadcast_pos": idx,
        }

    # Record the position of the (last) fill_constant writing each var.
    for idx, op in enumerate(block.ops):
        if op.type == "fill_constant":
            var_name = op.desc.output_arg_names()[0]
            if var_name in broadcast_vars:
                broadcast_vars[var_name]["fill_constant_pos"] = idx

    last_sync_comm_op_idx = -1
    last_sync_calc_op_idx = -1
    for idx, op in enumerate(block.ops):
        if op.type == "c_sync_comm_stream":
            last_sync_comm_op_idx = idx
            continue
        if op.type == "c_sync_calc_stream":
            last_sync_calc_op_idx = idx
            continue
        if op.type == "c_broadcast":
            var_name = op.desc.input_arg_names()[0]
            if "@BroadCast" in var_name:
                # fill_constant -> sync_calc -> broadcast must hold.
                fill_pos = broadcast_vars[var_name]["fill_constant_pos"]
                if fill_pos != -1:
                    assert last_sync_calc_op_idx != -1
                    assert fill_pos < last_sync_calc_op_idx
                    assert last_sync_calc_op_idx < idx
                continue
        # Any other reader of a broadcast var must run after a sync_comm
        # that itself follows the broadcast.
        for input_name in op.desc.input_arg_names():
            if input_name in broadcast_vars:
                assert (broadcast_vars[input_name]["broadcast_pos"] <
                        last_sync_comm_op_idx)
                assert last_sync_comm_op_idx < idx
    return


81
def check_allreduce_sum(block, shard, dp_ring_id=-1):
    """
    Validate the ordering of ``c_allreduce_sum`` ops and their sync ops.

    For each all-reduced var the expected op order is:

        - 0: op that generates the var
        - 1: c_sync_calc_stream
        - 2: c_allreduce_sum on ring 0 (sharding)
        - 3: c_sync_comm_stream on ring 0
        - 4: c_allreduce_sum on ``dp_ring_id`` (dp grads only)
        - 5: c_sync_comm_stream on ``dp_ring_id`` (dp grads only)
        - 6: op that uses the var

    Vars whose name contains "sum", or whose param (the name prefix before
    the first "@") is not held by ``shard``, are tracked as plain vars and
    only go through steps 0-3. Grads of params held by ``shard`` are
    tracked as dp grads; when ``dp_ring_id`` is not -1 they must also pass
    steps 4-5 before use.

    Additionally, the last "sum" (amp) all-reduce and the last
    ``c_allreduce_max`` (gradient clip) must both come after the last
    "@GRAD" all-reduce.

    Args:
        block: program block whose ``ops`` are inspected.
        shard: object exposing ``has_param(name)``.
        dp_ring_id: ring id of the data-parallel all-reduce, or -1 if none.

    Raises:
        ValueError: when a var is all-reduced before being generated or
            without a preceding sync_calc, or is used before its required
            sync_comm(s).
        AssertionError: for the remaining ordering violations.
    """
    vars_status = {}
    dp_grads_status = {}
    idx_last_grad_allreduce = -1
    idx_amp_allreduce = -1
    idx_gradient_clip_allreduce = -1

    # Pass 1: register every all-reduced var with status -1 (not generated)
    # and remember the positions needed for the amp / grad-clip checks.
    for idx, op in enumerate(block.ops):
        if op.type == "c_allreduce_sum":
            ring_id = op.desc.attr("ring_id")
            var_name = op.desc.input_arg_names()[0]
            param = var_name.split("@")[0]

            assert 'sum' in var_name or ("@GRAD" in var_name)
            if 'sum' in var_name or (not shard.has_param(param)):
                vars_status[var_name] = -1
            else:
                dp_grads_status[var_name] = -1

            if ring_id != 0:
                assert shard.has_param(param)
                assert ring_id == dp_ring_id

            if "sum" in var_name:
                idx_amp_allreduce = idx
            elif "@GRAD" in var_name:
                idx_last_grad_allreduce = idx

        if op.type == "c_allreduce_max":
            idx_gradient_clip_allreduce = idx

    # Pass 2: walk the ops in order and advance each var's status.
    for op in block.ops:
        if op.type == "c_sync_calc_stream":
            # A sync_calc promotes every generated (0) var to synced (1).
            for status in (vars_status, dp_grads_status):
                for var_name in status:
                    if status[var_name] == 0:
                        status[var_name] = 1

        elif op.type == "c_allreduce_sum":
            var_name = op.desc.input_arg_names()[0]
            ring_id = op.desc.attr("ring_id")
            if ring_id == 0:
                status = (vars_status
                          if var_name in vars_status else dp_grads_status)
                _status = status[var_name]
                if _status == -1:
                    raise ValueError("{} is not generated, but you are "
                                     "trying to all-reduce it".format(var_name))
                if _status == 0:
                    raise ValueError("There should be a sync_calc op "
                                     "after generate Var: {} and before the "
                                     "c_allreduce_sum op".format(var_name))
                assert _status == 1
                status[var_name] = 2
            else:
                assert ring_id == dp_ring_id
                param = var_name.split("@")[0]
                assert shard.has_param(param)
                assert dp_grads_status[var_name] == 3
                dp_grads_status[var_name] = 4

        elif op.type == "c_sync_comm_stream":
            ring_id = op.desc.attr("ring_id")
            if ring_id == 0:
                for var_name in op.desc.input_arg_names():
                    if var_name in vars_status:
                        assert vars_status[var_name] == 2
                        vars_status[var_name] = 3
                    elif var_name in dp_grads_status:
                        assert dp_grads_status[var_name] == 2
                        dp_grads_status[var_name] = 3
            else:
                for var_name in op.desc.input_arg_names():
                    param = var_name.split("@")[0]
                    assert ring_id == dp_ring_id
                    assert shard.has_param(param)
                    assert dp_grads_status[var_name] == 4
                    dp_grads_status[var_name] = 5

        else:
            # A regular op may only read all-reduced vars once fully synced.
            for input_name in op.desc.input_arg_names():
                if input_name in vars_status:
                    if vars_status[input_name] != 3:
                        raise ValueError("There should be a sync_comm op "
                                         "after allreduce the Var: {}".format(
                                             input_name))
                if input_name in dp_grads_status:
                    if dp_ring_id == -1:
                        if dp_grads_status[input_name] != 3:
                            raise ValueError("There should be a sync_comm op "
                                             "after allreduce the Var: {}".
                                             format(input_name))
                    else:
                        if dp_grads_status[input_name] != 5:
                            raise ValueError(
                                "The grad in shard should be allreduce and sync "
                                "twice before usage {}".format(input_name))

            # Writing a not-yet-generated var marks it as generated (0).
            for output_name in op.desc.output_arg_names():
                if vars_status.get(output_name) == -1:
                    vars_status[output_name] = 0
                if dp_grads_status.get(output_name) == -1:
                    dp_grads_status[output_name] = 0

    # check sharding with amp
    if idx_amp_allreduce != -1:
        assert idx_amp_allreduce > idx_last_grad_allreduce

    # check sharding with gradient_clip_by_global_norm
    if idx_gradient_clip_allreduce != -1:
        assert idx_gradient_clip_allreduce > idx_last_grad_allreduce

    return


J
JZ-LIANG 已提交
215 216 217 218 219 220 221 222 223 224 225 226 227 228
def get_valid_op_role(block, insert_idx):
    """
    Return the op role to use for an op inserted at ``insert_idx``.

    Scans ``block.ops`` forward starting at ``insert_idx``:

    * returns ``OpRole.Backward`` at the first op whose ``op_role`` is
      Backward or Optimize, or when the end of the block is reached
      (e.g. when inserting at ``len(block.ops)``);
    * returns ``OpRole.Forward`` at the first op whose ``op_role`` is
      Forward or Loss;
    * skips ops with any other role.
    """
    num_ops = len(block.ops)
    idx = insert_idx
    # Check the bound BEFORE indexing so that inserting at the end of the
    # block falls through to Backward instead of raising IndexError.
    while idx < num_ops:
        op_role = block.ops[idx].attr('op_role')
        if op_role in [int(OpRole.Backward), int(OpRole.Optimize)]:
            return OpRole.Backward
        if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
            return OpRole.Forward
        idx += 1
    return OpRole.Backward


229 230 231 232
def insert_sync_calc_op(block, insert_idx, calc_dep_vars):
    """
    Insert a ``c_sync_calc_stream`` op at position ``insert_idx`` of ``block``.

    ``calc_dep_vars`` is used as both the input ``X`` and the output ``Out``
    of the new op; its role comes from ``get_valid_op_role``.
    """
    sync_attrs = {OP_ROLE_KEY: get_valid_op_role(block, insert_idx)}
    block._insert_op_without_sync(
        insert_idx,
        type='c_sync_calc_stream',
        inputs={'X': calc_dep_vars},
        outputs={'Out': calc_dep_vars},
        attrs=sync_attrs)


243
def insert_sync_comm_op(block, insert_idx, ring_id, comm_dep_vars):
    """
    Insert one ``c_sync_comm_stream`` op on ``ring_id`` at ``insert_idx``.

    ``comm_dep_vars`` is used as both the input ``X`` and the output ``Out``;
    ``ring_id`` is forwarded as-is (no int conversion).

    Returns:
        int: number of ops inserted (always 1).
    """
    role = get_valid_op_role(block, insert_idx)
    sync_attrs = {'ring_id': ring_id, OP_ROLE_KEY: role}
    block._insert_op_without_sync(
        insert_idx,
        type='c_sync_comm_stream',
        inputs={'X': comm_dep_vars},
        outputs={'Out': comm_dep_vars},
        attrs=sync_attrs)
    inserted_count = 1
    return inserted_count


def insert_sync_comm_ops(block, insert_idx, ring_id, comm_dep_vars):
    """
    Insert one ``c_sync_comm_stream`` op covering all ``comm_dep_vars``.

    The op is placed at ``insert_idx`` on ring ``int(ring_id)`` and uses
    ``comm_dep_vars`` as both input ``X`` and output ``Out``.

    Returns:
        int: number of ops inserted (always 1).
    """
    role = get_valid_op_role(block, insert_idx)
    sync_attrs = {'ring_id': int(ring_id), OP_ROLE_KEY: role}
    block._insert_op_without_sync(
        insert_idx,
        type='c_sync_comm_stream',
        inputs={'X': comm_dep_vars},
        outputs={'Out': comm_dep_vars},
        attrs=sync_attrs)
    inserted_count = 1
    return inserted_count
271 272 273 274 275 276


def insert_fill_constant_ops(block, insert_idx, fill_constant_vars):
    """
    Insert a zero-valued ``fill_constant`` op at ``insert_idx`` for every
    var name in ``fill_constant_vars``.

    Each op writes the var returned by ``block.var(name)`` using that var's
    own shape and dtype. Every op is inserted at the same index, so
    (assuming ``_insert_op_without_sync`` places the op before the one
    currently at that index) they end up in reverse iteration order.
    """
    role = get_valid_op_role(block, insert_idx)
    for name in fill_constant_vars:
        target = block.var(name)
        fill_attrs = {
            "shape": target.shape,
            "dtype": target.dtype,
            "value": 0.0,
            OP_ROLE_KEY: role,
        }
        block._insert_op_without_sync(
            insert_idx,
            type="fill_constant",
            outputs={"Out": target.name},
            attrs=fill_attrs)


def insert_cast_ops(block, insert_idx, cast_ops):
    """
    Insert FP32 -> FP16 ``cast`` ops at ``insert_idx``.

    Args:
        block: block to insert into.
        insert_idx: position at which every cast op is inserted.
        cast_ops: mapping of fp16 output var name -> fp32 input var name;
            one ``cast`` op is inserted per entry.
    """
    role = get_valid_op_role(block, insert_idx)
    for half_name, full_name in cast_ops.items():
        cast_attrs = {
            "in_dtype": core.VarDesc.VarType.FP32,
            "out_dtype": core.VarDesc.VarType.FP16,
            OP_ROLE_KEY: role,
        }
        block._insert_op_without_sync(
            insert_idx,
            type="cast",
            inputs={"X": full_name},
            outputs={"Out": half_name},
            attrs=cast_attrs)


312
def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars):
    """
    Insert an in-place ``c_allreduce_sum`` op on ``ring_id`` at
    ``insert_idx`` for every entry of ``allreduce_vars``.

    Unlike the other insert helpers in this module, the op role is always
    ``OpRole.Backward`` rather than derived from neighbouring ops.
    """
    for reduce_var in allreduce_vars:
        reduce_attrs = {'ring_id': ring_id, OP_ROLE_KEY: OpRole.Backward}
        block._insert_op_without_sync(
            insert_idx,
            type='c_allreduce_sum',
            inputs={'X': reduce_var},
            outputs={'Out': reduce_var},
            attrs=reduce_attrs)


328
def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root):
    """
    Insert an in-place ``c_broadcast`` op on ``ring_id`` at ``insert_idx``
    for each ``(var_name, root)`` pair yielded by ``broadcast2root``.

    The op role comes from ``get_valid_op_role``.
    """
    role = get_valid_op_role(block, insert_idx)
    for var_name, root in broadcast2root:
        bcast_attrs = {
            'ring_id': ring_id,
            'root': root,
            OP_ROLE_KEY: role,
        }
        block._insert_op_without_sync(
            insert_idx,
            type='c_broadcast',
            inputs={'X': var_name},
            outputs={'Out': var_name},
            attrs=bcast_attrs)


# Bytes per element for each supported VarType. get_var_size() multiplies a
# var's element count by this value to compute its size; a dtype missing
# from this table makes that lookup raise KeyError.
DtypeToSize = {
    core.VarDesc.VarType.FP16: 2,
    core.VarDesc.VarType.FP32: 4,
    core.VarDesc.VarType.FP64: 8,
    core.VarDesc.VarType.INT16: 2,
    core.VarDesc.VarType.INT32: 4,
    core.VarDesc.VarType.INT64: 8,
    core.VarDesc.VarType.BOOL: 1,
    core.VarDesc.VarType.UINT8: 1,
}


def get_var_size(param):
    """
    Return the memory size of ``param`` in MB.

    Args:
        param: a var with a fully known ``shape`` (no -1 dims) and a
            ``dtype`` that is a key of ``DtypeToSize``.

    Returns:
        float: element count * bytes per element / 1024 / 1024. A var with
        an empty shape is counted as a single element.

    Raises:
        AssertionError: if the shape contains an unknown (-1) dimension.
        KeyError: if ``param.dtype`` is not in ``DtypeToSize``.
    """
    assert -1 not in param.shape
    # Initial value 1 so an empty (scalar) shape does not make reduce()
    # raise TypeError on an empty iterable.
    num_elements = reduce(lambda x, y: x * y, param.shape, 1)
    return num_elements * DtypeToSize[param.dtype] / 1024.0 / 1024.0


def insert_scale_loss_grad_ops(block, scale=1.0):
    '''
    Insert an in-place ``scale`` op right after every loss-grad op in
    ``block``, multiplying the loss grad by ``scale``.

    The purpose is to keep the learning rate consistent across different
    numbers of training workers (presumably callers pass a factor derived
    from the worker count -- confirm at call sites).
    '''
    indexed_ops = list(enumerate(block.ops))
    # Walk from the last op backwards so inserting after position idx does
    # not shift the positions of ops that are still to be visited.
    for idx, op in indexed_ops[::-1]:
        if not is_loss_grad_op(op):
            continue
        grad_var = block.vars[op.output_arg_names[0]]
        scale_attrs = {'scale': scale, OP_ROLE_KEY: OpRole.Backward}
        block._insert_op_without_sync(
            idx + 1,
            type='scale',
            inputs={'X': grad_var},
            outputs={'Out': grad_var},
            attrs=scale_attrs)
J
JZ-LIANG 已提交
387 388 389 390 391 392 393 394 395 396 397 398


def comm_analyse(main_program, gap=1):
    """
    Analyse the size of vars that are broadcast / all-reduced during
    sharding training.

    Prints the size (in KB) of every input of a ``c_broadcast`` or
    ``c_allreduce_sum`` op in the global block, then prints a histogram of
    those sizes bucketed into ``gap``-KB bins. The histogram is also written
    to ``nccl_size.txt`` in the current working directory.

    Args:
        main_program: program whose global block is analysed.
        gap: positive bucket width in KB (default 1).
    """
    reduce_vars = {}
    broadcast_vars = {}
    block = main_program.global_block()
    for op in block.ops:
        if op.type == "c_broadcast":
            var_name = op.desc.input_arg_names()[0]
            # get_var_size returns MB; convert to KB
            broadcast_vars[var_name] = get_var_size(block.var(
                var_name)) * 1024.0
        elif op.type == "c_allreduce_sum":
            var_name = op.desc.input_arg_names()[0]
            reduce_vars[var_name] = get_var_size(block.var(var_name)) * 1024.0

    # bucket index -> number of vars whose size falls into that bucket
    varsize_count = {}
    for kind, sizes in (("broadcast", broadcast_vars),
                        ("allreduce", reduce_vars)):
        for var_name, size in sizes.items():
            print("{}: {}: {} KB".format(kind, var_name, size))
            bucket = int(size / gap)
            varsize_count[bucket] = varsize_count.get(bucket, 0) + 1

    with open("nccl_size.txt", 'w', encoding="utf-8") as f:
        for bucket, count in sorted(varsize_count.items()):
            # Bucket bounds are expressed in KB, so scale the index by gap.
            line = "NCCL size {}~{} KB: {}".format(bucket * gap,
                                                   (bucket + 1) * gap, count)
            print(line)
            f.write(line + "\n")


J
JZ-LIANG 已提交
431
def add_sync_comm(program, dist_strategy):
    """
    Append ``c_sync_comm_stream`` ops for communicated vars left unsynced.

    When a test program is cloned from the sharding main program, some
    ``c_sync_comm_stream`` ops may be pruned by mistake. This function
    collects the inputs of ``c_broadcast`` / ``c_allreduce`` ops that no
    later ``c_sync_comm_stream`` consumes. If any remain, it appends one
    ``c_sync_comm_stream`` over all of them (in sorted name order) for each
    ring id in ``range(dist_strategy.nccl_comm_num)``.
    """
    # NOTE (liangjianzhong): only one comm stream is supported for now;
    # using more than one comm stream will cause errors. Should be revised
    # in the future.
    # NOTE(review): the exact match on "c_allreduce" does not match the
    # "c_allreduce_sum"/"c_allreduce_max" op types used elsewhere in this
    # file, so all-reduced vars are effectively not tracked -- confirm
    # whether that is intended.

    block = program.global_block()
    not_sync_vars = set()
    for op in block.ops:
        if op.type in ["c_broadcast", "c_allreduce"]:
            not_sync_vars.update(op.desc.input_arg_names())
        if op.type == "c_sync_comm_stream":
            # discard() rather than remove(): a sync op may also cover vars
            # that were never tracked above (e.g. all-reduced grads), which
            # previously raised KeyError.
            for input_name in op.desc.input_arg_names():
                not_sync_vars.discard(input_name)
    if not not_sync_vars:
        return

    # Sort so the generated op is deterministic across runs (set iteration
    # order of strings depends on hash randomization).
    sync_vars = sorted(not_sync_vars)
    for nccl_id in range(dist_strategy.nccl_comm_num):
        block.append_op(
            type='c_sync_comm_stream',
            inputs={'X': list(sync_vars)},
            outputs={'Out': list(sync_vars)},
            attrs={
                'ring_id': nccl_id,
                'op_role': core.op_proto_and_checker_maker.OpRole.Forward
            })
    return


J
JZ-LIANG 已提交
463
def save_persistables(exe, dirname, main_program, filename=None):
    """
    Save the persistable vars of a sharding ``main_program`` into ``dirname``.

    With sharding, some persistable vars are partitioned across ranks
    (each rank holds a unique part), while others are duplicated on every
    rank, possibly with different values.

    * The rank whose ``PADDLE_TRAINER_ID`` is 0 (or unset) saves every
      persistable var.
    * Every other rank saves only trainable ``Parameter`` vars and
      optimizer accumulators whose names end with one of the suffixes
      listed below.

    Note: ``filename`` is currently ignored; vars are always saved with
    ``filename=None``.
    """
    # NOTE(liangjianzhong): update these suffixes when adding a new
    # sharding-compatible optimizer; currently only Momentum and Adam are
    # compatible with sharding.
    opt_var_suffixes = ("_moment1_0", "_moment2_0", "_beta1_pow_acc_0",
                        "_beta2_pow_acc_0", "_velocity_0")

    def sharding_predicate(var):
        trainable = isinstance(
            var, paddle.fluid.framework.Parameter) and var.trainable
        return trainable or var.name.endswith(opt_var_suffixes)

    is_first_trainer = int(os.environ.get('PADDLE_TRAINER_ID', 0)) == 0
    if is_first_trainer:
        paddle.fluid.io.save_persistables(
            exe, dirname, main_program=main_program, filename=None)
        return

    paddle.fluid.io.save_vars(
        exe,
        dirname,
        main_program=main_program,
        predicate=sharding_predicate,
        filename=None)