#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
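"""Utility helpers for hybrid parallel training in dygraph mode: broadcasting
parameters and input data across model-/data-parallel groups and performing a
fused all_reduce of data-parallel gradients."""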
import os
import six
import numpy as np
import warnings

from paddle import framework
import paddle
from paddle.fluid import core
from paddle.fluid.dygraph.parallel import _split_tensors, sync_params_buffers, build_groups
from collections import OrderedDict
from .log_util import logger


def _apply_collective_grads(parameters, comm_group):
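    """Fuse the dense gradients of trainable ``parameters`` into ~128 MB
    buckets, average each bucket across ``comm_group`` via all_reduce, and
    split the averaged buckets back into the original gradient tensors.
    """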
    grad_var_set = set()
    grad_vars = []

    for param in parameters:
        if param.trainable and (param._grad_ivar() is not None):
            g_var = param._grad_ivar()
            assert not g_var._is_sparse(), \
                "sparse gradients are not supported yet"
            grad_vars.append(g_var)
            assert g_var not in grad_var_set
            grad_var_set.add(g_var)

    coalesced_grads_and_vars = build_groups(grad_vars, 128 * 1024 * 1024)

    for coalesced_grad, _, _ in coalesced_grads_and_vars:
        # Scale the fused buffer in place before the all_reduce so that the
        # averaged result lands in the same tensor that gets split back below;
        # rebinding the name via `coalesced_grad / nranks` would leave the
        # fused buffer unscaled.
        coalesced_grad.scale_(1.0 / comm_group.nranks)
        paddle.distributed.all_reduce(coalesced_grad, group=comm_group)

    _split_tensors(coalesced_grads_and_vars)


def broadcast_input_data(hcg, *inputs, **kwargs):
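    """Broadcast every tensor in ``inputs`` and ``kwargs`` from the source
    rank of the model-parallel group so that all ranks in the group consume
    identical input data; non-tensor values are reported and left unchanged.
    """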
    model_parallel_group = hcg.get_model_parallel_group()
    src_rank = hcg.get_model_parallel_group_src_rank()

    for input_ in inputs:
        if isinstance(input_, core.VarBase):
            with framework.no_grad():
                paddle.distributed.broadcast(
                    input_,
                    src=src_rank,
                    group=model_parallel_group,
                    use_calc_stream=True)
        else:
            logger.error("unsupported data type {}".format(type(input_)))

    for k, v in kwargs.items():
        if isinstance(v, core.VarBase):
            with framework.no_grad():
                paddle.distributed.broadcast(
                    v,
                    src=src_rank,
                    group=model_parallel_group,
                    use_calc_stream=True)
            kwargs[k] = v
        else:
            logger.error("unsupported data type {}".format(type(v)))
    return inputs, kwargs


def broadcast_mp_parameters(model, hcg):
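    """Sync the parameters and buffers of ``model`` across the model-parallel
    group, broadcasting from the group's source rank.
    """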
    model_parallel_group = hcg.get_model_parallel_group()
    src_rank = hcg.get_model_parallel_group_src_rank()
    sync_params_buffers(
        model, model_parallel_group, src_rank, is_model_parallel=True)


def broadcast_dp_parameters(model, hcg):
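    """Sync the parameters and buffers of ``model`` across the data-parallel
    group, broadcasting from the group's source rank.
    """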
    data_parallel_group = hcg.get_data_parallel_group()
    src_rank = hcg.get_data_parallel_group_src_rank()
    sync_params_buffers(
        model, data_parallel_group, src_rank, is_model_parallel=False)


def fused_allreduce_gradients(parameter_list, hcg):
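    """Run a fused (coalesced) all_reduce that averages the gradients of
    ``parameter_list`` across the data-parallel group.
    """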
    data_parallel_group = hcg.get_data_parallel_group()
    logger.debug("start fused allreduce of data-parallel gradients")
    with framework.no_grad():
        _apply_collective_grads(parameter_list, data_parallel_group)
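

# Illustrative usage sketch only, not part of this module. It assumes a
# training script that has already initialized hybrid parallelism (e.g. with
# paddle.distributed.fleet); `model`, `optimizer`, `loader` and `hcg` are
# placeholders supplied by that setup:
#
#     hcg = fleet.get_hybrid_communicate_group()
#     broadcast_mp_parameters(model, hcg)
#     broadcast_dp_parameters(model, hcg)
#     for batch in loader:
#         batch, _ = broadcast_input_data(hcg, *batch)
#         loss = model(*batch)
#         loss.backward()
#         fused_allreduce_gradients(list(model.parameters()), hcg)
#         optimizer.step()
#         optimizer.clear_grad()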