From 5ccc49e7a6dc14c462bc77d94530524ae7ab2d6e Mon Sep 17 00:00:00 2001 From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com> Date: Thu, 2 Jun 2022 17:41:29 +0800 Subject: [PATCH] support eager dygraph in moe_layer (#43168) --- python/paddle/distributed/collective.py | 4 +- python/paddle/distributed/parallel.py | 3 + .../distributed/models/moe/moe_layer.py | 199 ++++++++++++++++-- .../incubate/distributed/models/moe/utils.py | 17 +- 4 files changed, 195 insertions(+), 28 deletions(-) diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index 5f481bd0dc..fab6674b65 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -405,9 +405,7 @@ def new_group(ranks=None, backend=None): # TODO(shenliang03): This is a temporary solution to solve the problem of # hang caused by tcp - tmp = paddle.to_tensor([1], dtype="int32") - paddle.distributed.all_reduce(tmp, group=group, use_calc_stream=True) - paddle.distributed.wait(tmp) + paddle.distributed.barrier(group=group) return group if not backend: diff --git a/python/paddle/distributed/parallel.py b/python/paddle/distributed/parallel.py index 8cd6c4647d..f8c5b79e33 100644 --- a/python/paddle/distributed/parallel.py +++ b/python/paddle/distributed/parallel.py @@ -19,6 +19,7 @@ from multiprocessing import Process # noqa: F401 from multiprocessing import Manager # noqa: F401 import time import sys +import paddle from paddle import compat as cpt @@ -259,6 +260,8 @@ def init_parallel_env(): _set_group_map_by_name(_default_group_name, group) _set_group_map(0, group) parallel_helper._set_parallel_ctx(True) + + paddle.distributed.barrier(group=group) return group node_num = set([i.split(":")[0] for i in parallel_env.trainer_endpoints]) diff --git a/python/paddle/incubate/distributed/models/moe/moe_layer.py b/python/paddle/incubate/distributed/models/moe/moe_layer.py index ba22ffee3e..8ac0add801 100644 --- a/python/paddle/incubate/distributed/models/moe/moe_layer.py +++ b/python/paddle/incubate/distributed/models/moe/moe_layer.py @@ -31,11 +31,12 @@ from paddle.distributed import alltoall, all_gather from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker from paddle.distributed import fleet -from paddle.autograd import PyLayer +from paddle.autograd import PyLayer, EagerPyLayer from .gate import NaiveGate, GShardGate, SwitchGate, BaseGate from .utils import count_by_gate from paddle.distributed.fleet.meta_parallel.pp_utils.utils import _hp_recompute from paddle import fluid +from paddle.fluid.framework import in_dygraph_mode def _local_scatter(inp, pos): @@ -63,17 +64,26 @@ def _local_gather(inp, pos, out_batch_size, maybe_overlap=True): def _all_gather(tensor, group=None, use_calc_stream=True): - """ - The main difference with paddle.distributed.all_gather: - no need to pass in tensor_list, the returned tensor is spliced - """ if group is not None and not group.is_member(): return - ring_id = 0 if group is None else group.id - nranks = paddle.distributed.collective._get_global_group( - ).nranks if group is None else group.nranks - return paddle._C_ops.c_allgather(tensor, 'use_calc_stream', use_calc_stream, - 'ring_id', ring_id, 'nranks', nranks) + + if in_dygraph_mode(): + group = paddle.distributed.collective._get_default_group( + ) if group is None else group + tensor_shape = list(tensor.shape) + tensor_shape[0] *= group.nranks + out = paddle.empty(tensor_shape, tensor.dtype) + + task = group.process_group.all_gather(tensor, out) + task.wait() + return out + else: + ring_id = 0 if group is None else group.id + nranks = paddle.distributed.collective._get_global_group( + ).nranks if group is None else group.nranks + return paddle._C_ops.c_allgather(tensor, 'use_calc_stream', + use_calc_stream, 'ring_id', ring_id, + 'nranks', nranks) class MoEScatter(PyLayer): @@ -122,6 +132,52 @@ class MoEScatter(PyLayer): return grad_in, None, None, None +class EagerMoEScatter(EagerPyLayer): + r""" + Scatter input samples from [batch x sequences] to contiguous alone experts. + If `world_size` is greater than 1, the samples will first be locally + scattered, and then exchanged across workers. + """ + + @staticmethod + def forward(ctx, + inp, + pos, + local_expert_count, + global_expert_count, + fwd_batch_size, + world_size, + group=None): + local_input_buf = _local_scatter(inp, pos) + if world_size > 1: + global_input_buf = global_scatter( + local_input_buf, + local_expert_count, + global_expert_count, + group=group) + else: + global_input_buf = local_input_buf + + ctx.moe_args = inp.shape[0], world_size, group + + variables = (pos, local_expert_count, global_expert_count) + ctx.save_for_backward(*variables) + return global_input_buf + + @staticmethod + def backward(ctx, grad): + (pos, local_expert_count, global_expert_count) = ctx.saved_tensor() + (inp_batch_size, world_size, group) = ctx.moe_args + + if world_size > 1: + local_grad_in = global_gather( + grad, local_expert_count, global_expert_count, group=group) + else: + local_grad_in = grad + grad_in = _local_gather(local_grad_in, pos, inp_batch_size) + return grad_in, None, None, None + + class MoEGather(PyLayer): r""" Gather output samples from contiguous alone experts back to [batch x @@ -169,6 +225,53 @@ class MoEGather(PyLayer): return global_grad_out_buf, None, None, None +class EagerMoEGather(EagerPyLayer): + r""" + Gather output samples from contiguous alone experts back to [batch x + sequences]. Works symmetrically with MoEScatter. + """ + + @staticmethod + def forward(ctx, + global_output_buf, + pos, + local_expert_count, + global_expert_count, + local_batch_size, + world_size, + group=None): + if world_size > 1: + local_output_buf = global_gather( + global_output_buf, + local_expert_count, + global_expert_count, + group=group) + else: + local_output_buf = global_output_buf + output = _local_gather( + local_output_buf, pos, local_batch_size, maybe_overlap=False) + + ctx.moe_args = (global_output_buf.shape[0], world_size, group) + variables = (pos, local_expert_count, global_expert_count) + ctx.save_for_backward(*variables) + return output + + @staticmethod + def backward(ctx, grad_out): + pos, local_expert_count, global_expert_count = ctx.saved_tensor() + fwd_batch_size, world_size, group = ctx.moe_args + grad_out_buf = _local_scatter(grad_out, pos) + if world_size > 1: + global_grad_out_buf = global_scatter( + grad_out_buf, + local_expert_count, + global_expert_count, + group=group) + else: + global_grad_out_buf = grad_out_buf + return global_grad_out_buf, None, None, None + + class AllGather(PyLayer): r""" A wrapper for the All-Gather function to support auto-differentiation. @@ -189,6 +292,26 @@ class AllGather(PyLayer): grad_out, axes=[0], starts=[rank * dim0], ends=[(rank + 1) * dim0]) +class EagerAllGather(EagerPyLayer): + r""" + A wrapper for the All-Gather function to support auto-differentiation. + """ + + @staticmethod + def forward(ctx, inp, rank, world_size, group): + tensor_list = [] + paddle.distributed.all_gather(tensor_list, inp, group=group) + output = paddle.concat(tensor_list, axis=0) + ctx.args = rank, inp.shape[0] + return output + + @staticmethod + def backward(ctx, grad_out): + rank, dim0 = ctx.args + return paddle.slice( + grad_out, axes=[0], starts=[rank * dim0], ends=[(rank + 1) * dim0]) + + class Slice(PyLayer): r""" A wrapper for the Slice function to support auto-differentiation. @@ -208,11 +331,29 @@ class Slice(PyLayer): @staticmethod def backward(ctx, grad_out): world_size, group = ctx.args - # tensor_list = [] - # paddle.distributed.all_gather(tensor_list, grad_out, group=group) - # grad_out = paddle.concat(tensor_list, axis=0) return _all_gather(grad_out, group=group) - # return grad_out + + +class EagerSlice(EagerPyLayer): + r""" + A wrapper for the Slice function to support auto-differentiation. + """ + + @staticmethod + def forward(ctx, inp, rank, world_size, group): + B = inp.shape[0] + local_batch_size = B // world_size + batch_start = local_batch_size * rank + batch_end = min(batch_start + local_batch_size, B) + inp = paddle.slice( + inp, axes=[0], starts=[batch_start], ends=[batch_end]) + ctx.args = world_size, group + return inp + + @staticmethod + def backward(ctx, grad_out): + world_size, group = ctx.args + return _all_gather(grad_out, group=group) def prepare_forward(gate, num_expert, world_size, moe_group): @@ -369,7 +510,10 @@ class MoELayer(nn.Layer): mp_rank = self.mp_group.rank mp_size = self.mp_group.nranks if mp_size > 1: - inp = Slice.apply(inp, mp_rank, mp_size, self.mp_group) + if in_dygraph_mode(): + inp = EagerSlice.apply(inp, mp_rank, mp_size, self.mp_group) + else: + inp = Slice.apply(inp, mp_rank, mp_size, self.mp_group) value, gate = self.gate(inp) ( @@ -390,9 +534,14 @@ class MoELayer(nn.Layer): temp_pos = pos assert topk == self.top_k - x = MoEScatter.apply(inp, temp_pos, local_expert_count, - global_expert_count, fwd_batch_size, - self.world_size, self.group) + if in_dygraph_mode(): + x = EagerMoEScatter.apply(inp, temp_pos, local_expert_count, + global_expert_count, fwd_batch_size, + self.world_size, self.group) + else: + x = MoEScatter.apply(inp, temp_pos, local_expert_count, + global_expert_count, fwd_batch_size, + self.world_size, self.group) d_model = self.d_model @@ -421,15 +570,23 @@ class MoELayer(nn.Layer): if len(gate.shape) == 2: out_batch_size *= gate.shape[1] - x = MoEGather.apply(x, pos, local_expert_count, global_expert_count, - out_batch_size, self.world_size, self.group) + if in_dygraph_mode(): + x = EagerMoEGather.apply(x, pos, local_expert_count, + global_expert_count, out_batch_size, + self.world_size, self.group) + else: + x = MoEGather.apply(x, pos, local_expert_count, global_expert_count, + out_batch_size, self.world_size, self.group) x = x.reshape([-1, self.top_k, d_model]) value = value.reshape([x.shape[0], 1, self.top_k]) x = paddle.bmm(value, x).reshape([-1, d_model]) if mp_size > 1: - x = AllGather.apply(x, mp_rank, mp_size, self.mp_group) + if in_dygraph_mode(): + x = EagerAllGather.apply(x, mp_rank, mp_size, self.mp_group) + else: + x = AllGather.apply(x, mp_rank, mp_size, self.mp_group) x = paddle.reshape_(x, origin_shape) diff --git a/python/paddle/incubate/distributed/models/moe/utils.py b/python/paddle/incubate/distributed/models/moe/utils.py index 25c76c9753..09a6b788b7 100644 --- a/python/paddle/incubate/distributed/models/moe/utils.py +++ b/python/paddle/incubate/distributed/models/moe/utils.py @@ -21,15 +21,24 @@ from paddle.distributed.models.moe.utils import _number_count, _limit_by_capacity, _prune_gate_by_capacity, _assign_pos import paddle +from paddle.fluid.framework import in_dygraph_mode def _alltoall(in_tensor_list, group=None, use_calc_stream=True): if group is not None and not group.is_member(): return - ring_id = 0 if group is None else group.id - nranks = len(in_tensor_list) - return paddle._C_ops.alltoall(in_tensor_list, 'use_calc_stream', - use_calc_stream, 'ring_id', ring_id) + + if in_dygraph_mode(): + group = paddle.distributed.collective._get_default_group( + ) if group is None else group + out = paddle.empty(in_tensor_list.shape, in_tensor_list.dtype) + task = group.process_group.alltoall(in_tensor_list, out) + task.wait() + return out + else: + ring_id = 0 if group is None else group.id + return paddle._C_ops.alltoall(in_tensor_list, 'use_calc_stream', + use_calc_stream, 'ring_id', ring_id) def count_by_gate(gate, num_expert, world_size, require_pos=True, group=None): -- GitLab