Unverified commit 5ccc49e7 authored by Haohongxiang, committed by GitHub

support eager dygraph in moe_layer (#43168)

Parent 0fbf815c
@@ -405,9 +405,7 @@ def new_group(ranks=None, backend=None):
# TODO(shenliang03): This is a temporary solution to solve the problem of
# hang caused by tcp
tmp = paddle.to_tensor([1], dtype="int32")
paddle.distributed.all_reduce(tmp, group=group, use_calc_stream=True)
paddle.distributed.wait(tmp)
paddle.distributed.barrier(group=group)
return group
if not backend:
......
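# A minimal usage sketch of the new synchronization path (assumes a 2-GPU job
# launched with `python -m paddle.distributed.launch`): new_group() now issues
# paddle.distributed.barrier(group=group) itself instead of the old dummy
# all_reduce + wait, so callers can communicate on the group right away.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
group = dist.new_group(ranks=[0, 1])          # already barrier-synchronized
t = paddle.to_tensor([dist.get_rank()], dtype="int32")
dist.all_reduce(t, group=group)               # no warm-up collective needed
print(t.numpy())                              # each rank prints [1] (0 + 1)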
@@ -19,6 +19,7 @@ from multiprocessing import Process # noqa: F401
from multiprocessing import Manager # noqa: F401
import time
import sys
import paddle
from paddle import compat as cpt
@@ -259,6 +260,8 @@ def init_parallel_env():
_set_group_map_by_name(_default_group_name, group)
_set_group_map(0, group)
parallel_helper._set_parallel_ctx(True)
paddle.distributed.barrier(group=group)
return group
node_num = set([i.split(":")[0] for i in parallel_env.trainer_endpoints])
......
@@ -31,11 +31,12 @@ from paddle.distributed import alltoall, all_gather
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.distributed import fleet
from paddle.autograd import PyLayer
from paddle.autograd import PyLayer, EagerPyLayer
from .gate import NaiveGate, GShardGate, SwitchGate, BaseGate
from .utils import count_by_gate
from paddle.distributed.fleet.meta_parallel.pp_utils.utils import _hp_recompute
from paddle import fluid
from paddle.fluid.framework import in_dygraph_mode
def _local_scatter(inp, pos):
@@ -63,17 +64,26 @@ def _local_gather(inp, pos, out_batch_size, maybe_overlap=True):
def _all_gather(tensor, group=None, use_calc_stream=True):
"""
The main difference from paddle.distributed.all_gather:
no tensor_list needs to be passed in; the gathered result is returned as a single concatenated tensor
"""
if group is not None and not group.is_member():
return
if in_dygraph_mode():
group = paddle.distributed.collective._get_default_group(
) if group is None else group
tensor_shape = list(tensor.shape)
tensor_shape[0] *= group.nranks
out = paddle.empty(tensor_shape, tensor.dtype)
task = group.process_group.all_gather(tensor, out)
task.wait()
return out
else:
ring_id = 0 if group is None else group.id
nranks = paddle.distributed.collective._get_global_group(
).nranks if group is None else group.nranks
return paddle._C_ops.c_allgather(tensor, 'use_calc_stream', use_calc_stream,
'ring_id', ring_id, 'nranks', nranks)
return paddle._C_ops.c_allgather(tensor, 'use_calc_stream',
use_calc_stream, 'ring_id', ring_id,
'nranks', nranks)
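# A rough usage sketch for the _all_gather helper above (assumes 2 trainers and
# an initialized parallel environment). Unlike paddle.distributed.all_gather,
# no tensor_list is passed in; the result comes back as one tensor whose first
# dimension is multiplied by group.nranks, in both the eager and legacy paths.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
local = paddle.full([4, 8], float(dist.get_rank()))   # per-rank shard, shape [4, 8]
gathered = _all_gather(local)                          # shape [4 * nranks, 8]
print(gathered.shape)                                  # [8, 8] on a 2-GPU run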
class MoEScatter(PyLayer):
@@ -122,6 +132,52 @@ class MoEScatter(PyLayer):
return grad_in, None, None, None
class EagerMoEScatter(EagerPyLayer):
r"""
Scatter input samples from [batch x sequences] to contiguous along experts.
If `world_size` is greater than 1, the samples will first be locally
scattered, and then exchanged across workers.
"""
@staticmethod
def forward(ctx,
inp,
pos,
local_expert_count,
global_expert_count,
fwd_batch_size,
world_size,
group=None):
local_input_buf = _local_scatter(inp, pos)
if world_size > 1:
global_input_buf = global_scatter(
local_input_buf,
local_expert_count,
global_expert_count,
group=group)
else:
global_input_buf = local_input_buf
ctx.moe_args = inp.shape[0], world_size, group
variables = (pos, local_expert_count, global_expert_count)
ctx.save_for_backward(*variables)
return global_input_buf
@staticmethod
def backward(ctx, grad):
(pos, local_expert_count, global_expert_count) = ctx.saved_tensor()
(inp_batch_size, world_size, group) = ctx.moe_args
if world_size > 1:
local_grad_in = global_gather(
grad, local_expert_count, global_expert_count, group=group)
else:
local_grad_in = grad
grad_in = _local_gather(local_grad_in, pos, inp_batch_size)
return grad_in, None, None, None
class MoEGather(PyLayer):
r"""
Gather output samples from contiguous along experts back to [batch x
@@ -169,6 +225,53 @@ class MoEGather(PyLayer):
return global_grad_out_buf, None, None, None
class EagerMoEGather(EagerPyLayer):
r"""
Gather output samples from contiguous along experts back to [batch x
sequences]. Works symmetrically with MoEScatter.
"""
@staticmethod
def forward(ctx,
global_output_buf,
pos,
local_expert_count,
global_expert_count,
local_batch_size,
world_size,
group=None):
if world_size > 1:
local_output_buf = global_gather(
global_output_buf,
local_expert_count,
global_expert_count,
group=group)
else:
local_output_buf = global_output_buf
output = _local_gather(
local_output_buf, pos, local_batch_size, maybe_overlap=False)
ctx.moe_args = (global_output_buf.shape[0], world_size, group)
variables = (pos, local_expert_count, global_expert_count)
ctx.save_for_backward(*variables)
return output
@staticmethod
def backward(ctx, grad_out):
pos, local_expert_count, global_expert_count = ctx.saved_tensor()
fwd_batch_size, world_size, group = ctx.moe_args
grad_out_buf = _local_scatter(grad_out, pos)
if world_size > 1:
global_grad_out_buf = global_scatter(
grad_out_buf,
local_expert_count,
global_expert_count,
group=group)
else:
global_grad_out_buf = grad_out_buf
return global_grad_out_buf, None, None, None
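# A minimal EagerPyLayer sketch showing the ctx protocol the Eager* classes
# above rely on (save_for_backward in forward, saved_tensor in backward).
# `Square` is only an illustration, not part of the patch; it assumes the
# program is running in eager dygraph mode.
import paddle
from paddle.autograd import EagerPyLayer


class Square(EagerPyLayer):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)      # stash tensors needed by backward
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensor()       # retrieve them on the backward pass
        return 2.0 * x * grad_out


x = paddle.ones([2, 3]) * 3.0
x.stop_gradient = False
y = Square.apply(x)
y.sum().backward()
print(x.grad)                          # every element equals 2 * 3 = 6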
class AllGather(PyLayer):
r"""
A wrapper for the All-Gather function to support auto-differentiation.
@@ -189,6 +292,26 @@ class AllGather(PyLayer):
grad_out, axes=[0], starts=[rank * dim0], ends=[(rank + 1) * dim0])
class EagerAllGather(EagerPyLayer):
r"""
A wrapper for the All-Gather function to support auto-differentiation.
"""
@staticmethod
def forward(ctx, inp, rank, world_size, group):
tensor_list = []
paddle.distributed.all_gather(tensor_list, inp, group=group)
output = paddle.concat(tensor_list, axis=0)
ctx.args = rank, inp.shape[0]
return output
@staticmethod
def backward(ctx, grad_out):
rank, dim0 = ctx.args
return paddle.slice(
grad_out, axes=[0], starts=[rank * dim0], ends=[(rank + 1) * dim0])
class Slice(PyLayer):
r"""
A wrapper for the Slice function to support auto-differentiation.
@@ -208,11 +331,29 @@ class Slice(PyLayer):
@staticmethod
def backward(ctx, grad_out):
world_size, group = ctx.args
# tensor_list = []
# paddle.distributed.all_gather(tensor_list, grad_out, group=group)
# grad_out = paddle.concat(tensor_list, axis=0)
return _all_gather(grad_out, group=group)
# return grad_out
class EagerSlice(EagerPyLayer):
r"""
A wrapper for the Slice function to support auto-differentiation.
"""
@staticmethod
def forward(ctx, inp, rank, world_size, group):
B = inp.shape[0]
local_batch_size = B // world_size
batch_start = local_batch_size * rank
batch_end = min(batch_start + local_batch_size, B)
inp = paddle.slice(
inp, axes=[0], starts=[batch_start], ends=[batch_end])
ctx.args = world_size, group
return inp
@staticmethod
def backward(ctx, grad_out):
world_size, group = ctx.args
return _all_gather(grad_out, group=group)
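# A small sketch of how the Slice/AllGather pair composes under model
# parallelism (assumes 2 trainers, eager mode, and an initialized parallel
# environment). EagerSlice keeps only this rank's rows on the way in,
# EagerAllGather stitches the rows back on the way out, and their backward
# passes mirror each other.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
group = paddle.distributed.collective._get_default_group()
rank, nranks = group.rank, group.nranks

inp = paddle.rand([8, 16])
inp.stop_gradient = False
shard = EagerSlice.apply(inp, rank, nranks, group)       # [8 // nranks, 16]
out = EagerAllGather.apply(shard, rank, nranks, group)   # back to [8, 16]
out.sum().backward()                                     # grads flow through both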
def prepare_forward(gate, num_expert, world_size, moe_group):
@@ -369,6 +510,9 @@ class MoELayer(nn.Layer):
mp_rank = self.mp_group.rank
mp_size = self.mp_group.nranks
if mp_size > 1:
if in_dygraph_mode():
inp = EagerSlice.apply(inp, mp_rank, mp_size, self.mp_group)
else:
inp = Slice.apply(inp, mp_rank, mp_size, self.mp_group)
value, gate = self.gate(inp)
@@ -390,6 +534,11 @@ class MoELayer(nn.Layer):
temp_pos = pos
assert topk == self.top_k
if in_dygraph_mode():
x = EagerMoEScatter.apply(inp, temp_pos, local_expert_count,
global_expert_count, fwd_batch_size,
self.world_size, self.group)
else:
x = MoEScatter.apply(inp, temp_pos, local_expert_count,
global_expert_count, fwd_batch_size,
self.world_size, self.group)
@@ -421,6 +570,11 @@ class MoELayer(nn.Layer):
if len(gate.shape) == 2:
out_batch_size *= gate.shape[1]
if in_dygraph_mode():
x = EagerMoEGather.apply(x, pos, local_expert_count,
global_expert_count, out_batch_size,
self.world_size, self.group)
else:
x = MoEGather.apply(x, pos, local_expert_count, global_expert_count,
out_batch_size, self.world_size, self.group)
@@ -429,6 +583,9 @@ class MoELayer(nn.Layer):
x = paddle.bmm(value, x).reshape([-1, d_model])
if mp_size > 1:
if in_dygraph_mode():
x = EagerAllGather.apply(x, mp_rank, mp_size, self.mp_group)
else:
x = AllGather.apply(x, mp_rank, mp_size, self.mp_group)
x = paddle.reshape_(x, origin_shape)
......
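# The forward passes above all follow one dispatch pattern: in eager dygraph
# (in_dygraph_mode() is True) the EagerPyLayer variants are used, otherwise the
# legacy PyLayer ones. A condensed sketch of that pattern; `_scatter` is a
# hypothetical helper name, not part of the patch.
from paddle.fluid.framework import in_dygraph_mode


def _scatter(inp, pos, local_expert_count, global_expert_count, fwd_batch_size,
             world_size, group):
    # pick the custom-op implementation that matches the current dygraph mode
    layer = EagerMoEScatter if in_dygraph_mode() else MoEScatter
    return layer.apply(inp, pos, local_expert_count, global_expert_count,
                       fwd_batch_size, world_size, group)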
@@ -21,13 +21,22 @@
from paddle.distributed.models.moe.utils import _number_count, _limit_by_capacity, _prune_gate_by_capacity, _assign_pos
import paddle
from paddle.fluid.framework import in_dygraph_mode
def _alltoall(in_tensor_list, group=None, use_calc_stream=True):
if group is not None and not group.is_member():
return
if in_dygraph_mode():
group = paddle.distributed.collective._get_default_group(
) if group is None else group
out = paddle.empty(in_tensor_list.shape, in_tensor_list.dtype)
task = group.process_group.alltoall(in_tensor_list, out)
task.wait()
return out
else:
ring_id = 0 if group is None else group.id
nranks = len(in_tensor_list)
return paddle._C_ops.alltoall(in_tensor_list, 'use_calc_stream',
use_calc_stream, 'ring_id', ring_id)
......
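# A rough usage sketch for _alltoall (assumes 2 trainers, eager mode, and an
# initialized parallel environment). Each rank holds a count per
# (destination rank, expert); after the exchange every rank holds the counts
# addressed to it, which is roughly how global_expert_count is obtained from
# local_expert_count in count_by_gate.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
num_expert, world_size = 2, 2
# shape [world_size * num_expert]; rank r fills it with r + 1 for illustration
local_expert_count = paddle.full([world_size * num_expert],
                                 dist.get_rank() + 1, dtype="int64")
global_expert_count = _alltoall(local_expert_count)
print(global_expert_count.numpy())    # both ranks print [1 1 2 2]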