Unverified commit 5ccc49e7, authored by Haohongxiang, committed by GitHub

support eager dygraph in moe_layer (#43168)

Parent 0fbf815c
@@ -405,9 +405,7 @@ def new_group(ranks=None, backend=None):
         # TODO(shenliang03): This is a temporary solution to solve the problem of
         # hang caused by tcp
-        tmp = paddle.to_tensor([1], dtype="int32")
-        paddle.distributed.all_reduce(tmp, group=group, use_calc_stream=True)
-        paddle.distributed.wait(tmp)
+        paddle.distributed.barrier(group=group)
         return group

     if not backend:
......
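The hunk above replaces the old tcp-hang workaround, an all_reduce over a dummy tensor followed by an explicit wait, with a real barrier, so every rank blocks at group creation until all peers arrive. A minimal, hedged sketch of the new synchronization point (the script name demo.py and the two-process topology are assumptions):

```python
# Hedged sketch: run as `python -m paddle.distributed.launch --nproc_per_node=2 demo.py`
import paddle.distributed as dist

dist.init_parallel_env()

# Each rank blocks here until every rank has arrived, which is the
# synchronization the dummy all_reduce + wait pair used to approximate.
dist.barrier()
print("rank", dist.get_rank(), "passed the barrier")
```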
@@ -19,6 +19,7 @@ from multiprocessing import Process  # noqa: F401
 from multiprocessing import Manager  # noqa: F401
 import time
 import sys
+import paddle
 from paddle import compat as cpt
@@ -259,6 +260,8 @@ def init_parallel_env():
         _set_group_map_by_name(_default_group_name, group)
         _set_group_map(0, group)
         parallel_helper._set_parallel_ctx(True)
+        paddle.distributed.barrier(group=group)
         return group

     node_num = set([i.split(":")[0] for i in parallel_env.trainer_endpoints])
......
@@ -31,11 +31,12 @@ from paddle.distributed import alltoall, all_gather
 from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 from paddle.distributed import fleet
-from paddle.autograd import PyLayer
+from paddle.autograd import PyLayer, EagerPyLayer
 from .gate import NaiveGate, GShardGate, SwitchGate, BaseGate
 from .utils import count_by_gate
 from paddle.distributed.fleet.meta_parallel.pp_utils.utils import _hp_recompute
 from paddle import fluid
+from paddle.fluid.framework import in_dygraph_mode


 def _local_scatter(inp, pos):
@@ -63,17 +64,26 @@ def _local_gather(inp, pos, out_batch_size, maybe_overlap=True):
 def _all_gather(tensor, group=None, use_calc_stream=True):
+    """
+    The main difference with paddle.distributed.all_gather:
+    no need to pass in tensor_list, the returned tensor is spliced
+    """
     if group is not None and not group.is_member():
         return
-    ring_id = 0 if group is None else group.id
-    nranks = paddle.distributed.collective._get_global_group(
-    ).nranks if group is None else group.nranks
-    return paddle._C_ops.c_allgather(tensor, 'use_calc_stream', use_calc_stream,
-                                     'ring_id', ring_id, 'nranks', nranks)
+    if in_dygraph_mode():
+        group = paddle.distributed.collective._get_default_group(
+        ) if group is None else group
+        tensor_shape = list(tensor.shape)
+        tensor_shape[0] *= group.nranks
+        out = paddle.empty(tensor_shape, tensor.dtype)
+        task = group.process_group.all_gather(tensor, out)
+        task.wait()
+        return out
+    else:
+        ring_id = 0 if group is None else group.id
+        nranks = paddle.distributed.collective._get_global_group(
+        ).nranks if group is None else group.nranks
+        return paddle._C_ops.c_allgather(tensor, 'use_calc_stream',
+                                         use_calc_stream, 'ring_id', ring_id,
+                                         'nranks', nranks)
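In the eager branch above, _all_gather pre-allocates the spliced output (first dimension scaled by group.nranks), runs the collective through the group's ProcessGroup, and blocks on the returned task. A hedged sketch of the equivalent result through the public tensor_list API, which is exactly the boilerplate the helper saves its callers (two-process launch assumed):

```python
# Hedged sketch: run as `python -m paddle.distributed.launch --nproc_per_node=2 demo.py`
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
local = paddle.full([2, 4], float(dist.get_rank()))

# Public-API equivalent of _all_gather: collect shards into a list,
# then splice them along dim 0 into one [nranks * 2, 4] tensor.
shards = []
dist.all_gather(shards, local)
spliced = paddle.concat(shards, axis=0)
print(spliced.shape)  # [4, 4] with two ranks
```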
class MoEScatter(PyLayer):
@@ -122,6 +132,52 @@ class MoEScatter(PyLayer):
         return grad_in, None, None, None
+class EagerMoEScatter(EagerPyLayer):
+    r"""
+    Scatter input samples from [batch x sequences] to contiguous alone experts.
+    If `world_size` is greater than 1, the samples will first be locally
+    scattered, and then exchanged across workers.
+    """
+
+    @staticmethod
+    def forward(ctx,
+                inp,
+                pos,
+                local_expert_count,
+                global_expert_count,
+                fwd_batch_size,
+                world_size,
+                group=None):
+        local_input_buf = _local_scatter(inp, pos)
+        if world_size > 1:
+            global_input_buf = global_scatter(
+                local_input_buf,
+                local_expert_count,
+                global_expert_count,
+                group=group)
+        else:
+            global_input_buf = local_input_buf
+
+        ctx.moe_args = inp.shape[0], world_size, group
+
+        variables = (pos, local_expert_count, global_expert_count)
+        ctx.save_for_backward(*variables)
+        return global_input_buf
+
+    @staticmethod
+    def backward(ctx, grad):
+        (pos, local_expert_count, global_expert_count) = ctx.saved_tensor()
+        (inp_batch_size, world_size, group) = ctx.moe_args
+
+        if world_size > 1:
+            local_grad_in = global_gather(
+                grad, local_expert_count, global_expert_count, group=group)
+        else:
+            local_grad_in = grad
+        grad_in = _local_gather(local_grad_in, pos, inp_batch_size)
+        return grad_in, None, None, None
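EagerMoEScatter duplicates MoEScatter line for line; only the base class differs, because eager dygraph routes custom autograd through EagerPyLayer instead of PyLayer. The ctx protocol is unchanged, as the toy op below shows; this Scale class is purely illustrative, not part of the commit, and the same code works with EagerPyLayer swapped in:

```python
import paddle
from paddle.autograd import PyLayer  # EagerPyLayer exposes the same interface


class Scale(PyLayer):
    """Illustrative op y = x * w using the same ctx pattern as MoEScatter."""

    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(w)   # tensors needed in backward go through ctx
        ctx.args = x.shape[0]      # plain Python state can ride on ctx too
        return x * w

    @staticmethod
    def backward(ctx, grad):
        w, = ctx.saved_tensor()
        # one gradient per tensor input; None for non-differentiable ones
        return grad * w, None


x = paddle.ones([4])
x.stop_gradient = False
w = paddle.full([4], 3.0)
Scale.apply(x, w).sum().backward()
print(x.grad)  # gradient of sum(x * w) w.r.t. x is w: [3., 3., 3., 3.]
```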
class MoEGather(PyLayer):
    r"""
    Gather output samples from contiguous alone experts back to [batch x
@@ -169,6 +225,53 @@ class MoEGather(PyLayer):
         return global_grad_out_buf, None, None, None
+class EagerMoEGather(EagerPyLayer):
+    r"""
+    Gather output samples from contiguous alone experts back to [batch x
+    sequences]. Works symmetrically with MoEScatter.
+    """
+
+    @staticmethod
+    def forward(ctx,
+                global_output_buf,
+                pos,
+                local_expert_count,
+                global_expert_count,
+                local_batch_size,
+                world_size,
+                group=None):
+        if world_size > 1:
+            local_output_buf = global_gather(
+                global_output_buf,
+                local_expert_count,
+                global_expert_count,
+                group=group)
+        else:
+            local_output_buf = global_output_buf
+        output = _local_gather(
+            local_output_buf, pos, local_batch_size, maybe_overlap=False)
+
+        ctx.moe_args = (global_output_buf.shape[0], world_size, group)
+        variables = (pos, local_expert_count, global_expert_count)
+        ctx.save_for_backward(*variables)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        pos, local_expert_count, global_expert_count = ctx.saved_tensor()
+        fwd_batch_size, world_size, group = ctx.moe_args
+
+        grad_out_buf = _local_scatter(grad_out, pos)
+        if world_size > 1:
+            global_grad_out_buf = global_scatter(
+                grad_out_buf,
+                local_expert_count,
+                global_expert_count,
+                group=group)
+        else:
+            global_grad_out_buf = grad_out_buf
+        return global_grad_out_buf, None, None, None
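EagerMoEGather is the mirror image of EagerMoEScatter: its forward is the scatter's backward and vice versa, so together they round-trip samples between batch order and expert-contiguous order. A hedged single-process sketch of the local half of that round trip, using public paddle.gather and paddle.scatter in place of the _local_scatter/_local_gather helpers:

```python
import paddle

inp = paddle.arange(12, dtype='float32').reshape([4, 3])
pos = paddle.to_tensor([2, 0, 3, 1])  # a permutation produced by the gate

# local scatter: reorder rows so each expert's samples are contiguous
scattered = paddle.gather(inp, pos)

# local gather: the inverse write-back restores the original batch order
restored = paddle.scatter(paddle.zeros_like(inp), pos, scattered)
print(bool((restored == inp).all()))  # True
```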
class AllGather(PyLayer):
    r"""
    A wrapper for the All-Gather function to support auto-differentiation.
@@ -189,6 +292,26 @@ class AllGather(PyLayer):
             grad_out, axes=[0], starts=[rank * dim0], ends=[(rank + 1) * dim0])
+class EagerAllGather(EagerPyLayer):
+    r"""
+    A wrapper for the All-Gather function to support auto-differentiation.
+    """
+
+    @staticmethod
+    def forward(ctx, inp, rank, world_size, group):
+        tensor_list = []
+        paddle.distributed.all_gather(tensor_list, inp, group=group)
+        output = paddle.concat(tensor_list, axis=0)
+        ctx.args = rank, inp.shape[0]
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        rank, dim0 = ctx.args
+        return paddle.slice(
+            grad_out, axes=[0], starts=[rank * dim0], ends=[(rank + 1) * dim0])
class Slice(PyLayer):
    r"""
    A wrapper for the Slice function to support auto-differentiation.
@@ -208,11 +331,29 @@ class Slice(PyLayer):
     @staticmethod
     def backward(ctx, grad_out):
         world_size, group = ctx.args
-        # tensor_list = []
-        # paddle.distributed.all_gather(tensor_list, grad_out, group=group)
-        # grad_out = paddle.concat(tensor_list, axis=0)
         return _all_gather(grad_out, group=group)
-        # return grad_out
+class EagerSlice(EagerPyLayer):
+    r"""
+    A wrapper for the Slice function to support auto-differentiation.
+    """
+
+    @staticmethod
+    def forward(ctx, inp, rank, world_size, group):
+        B = inp.shape[0]
+        local_batch_size = B // world_size
+        batch_start = local_batch_size * rank
+        batch_end = min(batch_start + local_batch_size, B)
+        inp = paddle.slice(
+            inp, axes=[0], starts=[batch_start], ends=[batch_end])
+        ctx.args = world_size, group
+        return inp
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        world_size, group = ctx.args
+        return _all_gather(grad_out, group=group)
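EagerSlice and EagerAllGather form the model-parallel pair: slicing the batch in forward implies all-gathering gradients in backward, and all-gathering in forward implies slicing in backward. A hedged single-process sketch of that adjoint relationship, faking a world_size of 2 by concatenating shards locally instead of communicating:

```python
import paddle
from paddle.autograd import PyLayer

WORLD_SIZE, RANK = 2, 0  # pretend topology; no real communication below


class FakeSlice(PyLayer):
    """Forward keeps this rank's shard; backward 'all-gathers' the grad."""

    @staticmethod
    def forward(ctx, inp):
        shard = inp.shape[0] // WORLD_SIZE
        return paddle.slice(
            inp, axes=[0], starts=[RANK * shard], ends=[(RANK + 1) * shard])

    @staticmethod
    def backward(ctx, grad_out):
        # stand-in for _all_gather: each rank would contribute one shard
        return paddle.concat([grad_out] * WORLD_SIZE, axis=0)


x = paddle.rand([4, 3])
x.stop_gradient = False
FakeSlice.apply(x).sum().backward()
print(x.grad.shape)  # [4, 3]: the gradient covers the full batch again
```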
def prepare_forward(gate, num_expert, world_size, moe_group):

@@ -369,7 +510,10 @@ class MoELayer(nn.Layer):
             mp_rank = self.mp_group.rank
             mp_size = self.mp_group.nranks
         if mp_size > 1:
-            inp = Slice.apply(inp, mp_rank, mp_size, self.mp_group)
+            if in_dygraph_mode():
+                inp = EagerSlice.apply(inp, mp_rank, mp_size, self.mp_group)
+            else:
+                inp = Slice.apply(inp, mp_rank, mp_size, self.mp_group)
         value, gate = self.gate(inp)
         (
@@ -390,9 +534,14 @@ class MoELayer(nn.Layer):
         temp_pos = pos
         assert topk == self.top_k

-        x = MoEScatter.apply(inp, temp_pos, local_expert_count,
-                             global_expert_count, fwd_batch_size,
-                             self.world_size, self.group)
+        if in_dygraph_mode():
+            x = EagerMoEScatter.apply(inp, temp_pos, local_expert_count,
+                                      global_expert_count, fwd_batch_size,
+                                      self.world_size, self.group)
+        else:
+            x = MoEScatter.apply(inp, temp_pos, local_expert_count,
+                                 global_expert_count, fwd_batch_size,
+                                 self.world_size, self.group)

         d_model = self.d_model
@@ -421,15 +570,23 @@ class MoELayer(nn.Layer):
         if len(gate.shape) == 2:
             out_batch_size *= gate.shape[1]

-        x = MoEGather.apply(x, pos, local_expert_count, global_expert_count,
-                            out_batch_size, self.world_size, self.group)
+        if in_dygraph_mode():
+            x = EagerMoEGather.apply(x, pos, local_expert_count,
+                                     global_expert_count, out_batch_size,
+                                     self.world_size, self.group)
+        else:
+            x = MoEGather.apply(x, pos, local_expert_count, global_expert_count,
+                                out_batch_size, self.world_size, self.group)

         x = x.reshape([-1, self.top_k, d_model])
         value = value.reshape([x.shape[0], 1, self.top_k])
         x = paddle.bmm(value, x).reshape([-1, d_model])

         if mp_size > 1:
-            x = AllGather.apply(x, mp_rank, mp_size, self.mp_group)
+            if in_dygraph_mode():
+                x = EagerAllGather.apply(x, mp_rank, mp_size, self.mp_group)
+            else:
+                x = AllGather.apply(x, mp_rank, mp_size, self.mp_group)

         x = paddle.reshape_(x, origin_shape)
......
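Every call site in MoELayer.forward now follows the same dispatch shape: probe in_dygraph_mode() and route to the EagerPyLayer variant or the legacy one. A hedged sketch of the pattern in isolation; Double is a placeholder op, and the eager alias only stands in for the commit's real Eager* classes:

```python
import paddle
from paddle.autograd import PyLayer
from paddle.fluid.framework import in_dygraph_mode


class Double(PyLayer):
    """Placeholder legacy-graph op: y = 2 * x."""

    @staticmethod
    def forward(ctx, x):
        return 2 * x

    @staticmethod
    def backward(ctx, grad):
        return 2 * grad


EagerDouble = Double  # stand-in for a real EagerPyLayer subclass


def dispatch(x):
    # the branch shape used throughout MoELayer.forward in this commit
    if in_dygraph_mode():
        return EagerDouble.apply(x)
    return Double.apply(x)


x = paddle.ones([2])
x.stop_gradient = False
print(dispatch(x))  # Tensor [2., 2.] on either autograd engine
```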
@@ -21,15 +21,24 @@
 from paddle.distributed.models.moe.utils import _number_count, _limit_by_capacity, _prune_gate_by_capacity, _assign_pos
 import paddle
+from paddle.fluid.framework import in_dygraph_mode


 def _alltoall(in_tensor_list, group=None, use_calc_stream=True):
     if group is not None and not group.is_member():
         return
-    ring_id = 0 if group is None else group.id
-    nranks = len(in_tensor_list)
-    return paddle._C_ops.alltoall(in_tensor_list, 'use_calc_stream',
-                                  use_calc_stream, 'ring_id', ring_id)
+    if in_dygraph_mode():
+        group = paddle.distributed.collective._get_default_group(
+        ) if group is None else group
+        out = paddle.empty(in_tensor_list.shape, in_tensor_list.dtype)
+        task = group.process_group.alltoall(in_tensor_list, out)
+        task.wait()
+        return out
+    else:
+        ring_id = 0 if group is None else group.id
+        return paddle._C_ops.alltoall(in_tensor_list, 'use_calc_stream',
+                                      use_calc_stream, 'ring_id', ring_id)


 def count_by_gate(gate, num_expert, world_size, require_pos=True, group=None):
......
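The eager branch of _alltoall mirrors _all_gather: allocate the output up front, drive the exchange through the group's ProcessGroup, and wait on the task. Note that despite its name, in_tensor_list reaches this helper as a single stacked tensor, which is why .shape and .dtype work on it. A hedged sketch of the same exchange via the public list-based API (two-process launch assumed):

```python
# Hedged sketch: run as `python -m paddle.distributed.launch --nproc_per_node=2 demo.py`
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
rank = dist.get_rank()

# Each rank owns one chunk per peer; alltoall sends chunk i to rank i.
chunks = [paddle.full([2, 2], float(rank * 10 + i)) for i in range(2)]
out = []
dist.alltoall(chunks, out)
# Rank r now holds the r-th chunk from every rank.
print(rank, [float(t[0, 0]) for t in out])
```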