hybrid_parallel_util.py
#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
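"""Utilities for hybrid parallel training in dygraph mode: fused gradient
all-reduce over data- and sharding-parallel groups, parameter synchronization
for the model-, data-, and sharding-parallel groups, and input-data broadcast
within the model-parallel group of a hybrid communication group (hcg).
"""
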
import os
import six
import numpy as np

from paddle import framework
import paddle
from paddle.fluid import core
from paddle.fluid.dygraph.parallel import _split_tensors, sync_params_buffers, build_groups
from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
from collections import OrderedDict
from .log_util import logger

__all__ = []


def _apply_collective_grads(parameters, comm_group):
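    """Fuse the dense gradients of ``parameters`` into ~128MB buckets, divide
    each bucket by the number of ranks in ``comm_group`` (or by the global
    world size when ``comm_group`` is None), all-reduce it, and then split the
    buckets back into the original gradient tensors.
    """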
    grad_var_set = set()
    grad_vars = []
    sparse_grad_vars = []

    for param in parameters:
        if param.trainable and (param._grad_ivar() is not None):
            g_var = param._grad_ivar()
            assert not g_var._is_sparse(
            ), "Sparse parameters are not supported yet"
            grad_vars.append(g_var)
            assert g_var not in grad_var_set
            grad_var_set.add(g_var)

    coalesced_grads_and_vars = build_groups(grad_vars, 128 * 1024 * 1024)

    nranks = paddle.distributed.get_world_size(
    ) if comm_group is None else comm_group.nranks
    for coalesced_grad, _, _ in coalesced_grads_and_vars:
        # need to div nranks
        div_factor = paddle.to_tensor(nranks, dtype=coalesced_grad.dtype)
        paddle.fluid.framework._dygraph_tracer().trace_op(
            type="elementwise_div",
            inputs={
                'X': coalesced_grad,
                'Y': div_factor
            },
            outputs={'Out': coalesced_grad},
            attrs={'axis': -1})
        paddle.distributed.all_reduce(coalesced_grad, group=comm_group)

    _split_tensors(coalesced_grads_and_vars)


def _apply_collective_grads_eager(parameters, comm_group):
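    """Eager-mode counterpart of ``_apply_collective_grads``: fuses dense
    gradients into ~128MB buckets, scales them in place by 1/nranks, and
    all-reduces them over ``comm_group`` before splitting back.
    """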
    grad_var_set = set()
    grad_vars = []

    for param in parameters:
        if param.trainable and (param._grad_ivar() is not None):
            g_var = param._grad_ivar()
            assert not g_var.is_sparse(
            ), "Sparse parameters are not supported yet"
            grad_vars.append(g_var)
            assert g_var not in grad_var_set
            grad_var_set.add(g_var)

    coalesced_grads_and_vars = build_groups(grad_vars, 128 * 1024 * 1024)

    nranks = paddle.distributed.get_world_size(
    ) if comm_group is None else comm_group.nranks
    for coalesced_grad, _, _ in coalesced_grads_and_vars:
        # need to div nranks
        coalesced_grad.scale_(1.0 / nranks)
        paddle.distributed.all_reduce(coalesced_grad, group=comm_group)

    _split_tensors(coalesced_grads_and_vars)


def _broadcast_data_help(data, shape, dtype, hcg):
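    """Broadcast ``data`` from the source rank of the model-parallel group:
    the shape is broadcast first so that non-source ranks can allocate a
    receive buffer, then the tensor itself is broadcast and shared back into
    ``data`` on those ranks.
    """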
    model_parallel_group = hcg.get_model_parallel_group()
    src_rank = hcg.get_model_parallel_group_src_rank()
    mp_rank = hcg.get_model_parallel_rank()

    shape_gpu = paddle.to_tensor(shape, dtype="int32")
    paddle.distributed.broadcast(shape_gpu,
                                 src=src_rank,
                                 group=model_parallel_group,
                                 use_calc_stream=True)

    if mp_rank != 0:
        input_data = paddle.zeros(shape_gpu, dtype=dtype)
    else:
        input_data = data

    paddle.distributed.broadcast(input_data,
                                 src=src_rank,
                                 group=model_parallel_group,
                                 use_calc_stream=True)

    if mp_rank != 0:
        if in_dygraph_mode():
            data._clear_data()
            input_data._share_buffer_to(data)
        else:
            data.value().get_tensor()._clear()
            data.value().get_tensor()._share_data_with(
                input_data.value().get_tensor())


def broadcast_input_data(hcg, *inputs, **kwargs):
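    """Broadcast all tensor arguments in ``inputs`` and ``kwargs`` from the
    source rank of the model-parallel group. In eager mode, tensors not yet on
    the GPU are moved to the current CUDA device first; non-tensor values are
    logged and left unchanged.
    """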
    cur_device = paddle.get_device()
    for v in inputs:
        if isinstance(v, (core.VarBase, core.eager.Tensor)):
            with framework.no_grad():
                if "gpu" in cur_device and in_dygraph_mode() \
                    and not v.place.is_gpu_place():
                    v_gpu = v.cuda(int(cur_device.split(":")[1]))
                    v._clear_data()
                    v_gpu._share_buffer_to(v)
                _broadcast_data_help(v, v.shape, v.dtype, hcg)
        else:
            logger.error("unsupported data type {} for broadcast".format(
                type(v)))

    for k, v in kwargs.items():
        if isinstance(v, (core.VarBase, core.eager.Tensor)):
            with framework.no_grad():
                if "gpu" in cur_device and in_dygraph_mode() \
                    and not v.place.is_gpu_place():
                    v_gpu = v.cuda(int(cur_device.split(":")[1]))
                    v._clear_data()
                    v_gpu._share_buffer_to(v)
                _broadcast_data_help(v, v.shape, v.dtype, hcg)
            kwargs[k] = v
        else:
            logger.error("unsupported data type {} for broadcast".format(
                type(v)))
    return inputs, kwargs


def broadcast_mp_parameters(model, hcg):
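    """Synchronize the parameters and buffers of ``model`` from the source
    rank of the model-parallel group.
    """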
    model_parallel_group = hcg.get_model_parallel_group()
    src_rank = hcg.get_model_parallel_group_src_rank()
    sync_params_buffers(model,
                        model_parallel_group,
                        src_rank,
                        is_model_parallel=True)


def broadcast_dp_parameters(model, hcg):
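    """Synchronize the parameters and buffers of ``model`` from the source
    rank of the data-parallel group.
    """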
    data_parallel_group = hcg.get_data_parallel_group()
    src_rank = hcg.get_data_parallel_group_src_rank()
    sync_params_buffers(model,
                        data_parallel_group,
                        src_rank,
                        is_model_parallel=False)


def fused_allreduce_gradients(parameter_list, hcg):
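    """Fuse and all-reduce the gradients of ``parameter_list`` over the
    data-parallel group of ``hcg`` (or over the global group when ``hcg`` is
    None), dispatching to the eager or legacy implementation depending on the
    dygraph mode.
    """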
    data_parallel_group = None if hcg is None else hcg.get_data_parallel_group()
    logger.debug("dp start fuse allreduce gradients")
    apply_func = _apply_collective_grads_eager if in_dygraph_mode(
    ) else _apply_collective_grads
    with framework.no_grad():
        apply_func(parameter_list, data_parallel_group)


def sharding_reduce_gradients(parameter_list, hcg):
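    """All-reduce and average the gradients of ``parameter_list`` over the
    sharding-parallel group of ``hcg``. Eager mode scales each gradient in
    place and calls all_reduce; legacy dygraph traces c_allreduce_sum and
    elementwise_div ops instead.
    """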
    # TODO allreduce --> reduce
    # TODO merge grad / nrank with dp
    logger.debug("sharding start gradients sync")
    with framework.no_grad():

        sharding_nrank = hcg.get_sharding_parallel_group().nranks
        for param in parameter_list:
            if param.trainable and (param._grad_ivar() is not None):
                if in_dygraph_mode():
                    param.grad.scale_(1.0 / sharding_nrank)
                    paddle.distributed.all_reduce(
                        param.grad,
                        group=hcg.get_sharding_parallel_group(),
                        use_calc_stream=True)

                elif _in_legacy_dygraph():
                    g_var = param._grad_ivar()
                    # need to use trace_op to all-reduce in legacy dygraph
                    # paddle.distributed.all_reduce(
                    #     g_var, group=hcg.get_sharding_parallel_group(), use_calc_stream=True)
                    paddle.fluid.framework._dygraph_tracer().trace_op(
                        type="c_allreduce_sum",
                        inputs={'X': g_var},
                        outputs={'Out': g_var},
                        attrs={
                            'ring_id': hcg.get_sharding_parallel_group().id,
                            'use_calc_stream': True
                        })

                    # divide grad by sharding_nrank
                    div_factor = paddle.to_tensor(sharding_nrank,
                                                  dtype=g_var.dtype)
                    paddle.fluid.framework._dygraph_tracer().trace_op(
                        type="elementwise_div",
                        inputs={
                            'X': g_var,
                            'Y': div_factor
                        },
                        outputs={'Out': g_var},
                        attrs={'axis': -1})


def broadcast_sharding_parameters(model, hcg):
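    """Synchronize the parameters and buffers of ``model`` from the source
    rank of the sharding-parallel group.
    """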
    # TODO: to save memory, use un-fused broadcast to avoid potential OOM
    logger.debug("sharding start init parameters sync")
    sharding_parallel_group = hcg.get_sharding_parallel_group()
    src_rank = hcg.get_sharding_parallel_group_src_rank()
    sync_params_buffers(model,
                        sharding_parallel_group,
                        src_rank,
                        is_model_parallel=False)