未验证 提交 4cc3d9a2 编写于 作者: S ShenLiang 提交者: GitHub

[HybridParallel]Fix bug of p2p for partial_send/recv (#34615)

* fix bug of p2p for partial

* fix error
上级 090c863a
......@@ -64,18 +64,6 @@ class PipelineParallel(MetaParallelBase):
logger.info("start broadcast dp parameters")
broadcast_dp_parameters(self._layers, self._hcg)
def _set_tensor_trainable(self, tensor):
if tensor is None:
return
if isinstance(tensor, tuple):
for t in tensor:
if is_float_tensor(t):
t.stop_gradient = False
else:
if is_float_tensor(tensor):
tensor.stop_gradient = False
def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
assert isinstance(optimizer, HybridParallelOptimizer), (
'optimizer should be HybridParallelOptimizer subclass.')
......@@ -117,7 +105,6 @@ class PipelineParallel(MetaParallelBase):
for step_id in range(startup_steps):
input_tensor = p2p.recv_forward()
self._set_tensor_trainable(input_tensor)
output_tensor = self._forward_step(input_tensor)
p2p.send_forward(output_tensor)
......@@ -131,7 +118,6 @@ class PipelineParallel(MetaParallelBase):
for i in range(steady_steps):
last_iter = (i == (steady_steps - 1))
self._set_tensor_trainable(input_tensor)
output_tensor = self._forward_step(input_tensor)
output_tensor_grad = p2p.send_forward_recv_backward(output_tensor)
......
......@@ -15,6 +15,8 @@
import paddle
from .utils import paddle_2_number, number_2_dtype
from ...utils.log_util import logger
import numpy as np
from paddle import _C_ops
_hcg = None
......@@ -40,6 +42,7 @@ class SendRecvMeta:
self.recv_shape_message = None
self.recv_dtype_message = None
self.recv_stop_gradient = None
self.has_send_meta = False
self.has_recv_meta = False
......@@ -57,7 +60,11 @@ class SendRecvMeta:
# recv dtype
dtype = paddle.to_tensor([0])
paddle.distributed.recv(dtype, src=0, group=group)
return shape.numpy().tolist(), dtype.item()
# recv stop_gradient
stop_grad = paddle.to_tensor([0])
paddle.distributed.recv(stop_grad, src=0, group=group)
return shape.numpy().tolist(), dtype.item(), stop_grad.item()
def recv_meta(self, group):
tensor_type = paddle.to_tensor([0])
......@@ -65,9 +72,10 @@ class SendRecvMeta:
tensor_type = tensor_type.item()
if tensor_type == 0:
shape, dtype = self._recv_shape_dtype(group)
shape, dtype, stop_grad = self._recv_shape_dtype(group)
self.recv_shape_message = shape
self.recv_dtype_message = dtype
self.recv_stop_gradient = bool(stop_grad)
elif tensor_type == 1:
num = paddle.to_tensor([0])
......@@ -75,13 +83,16 @@ class SendRecvMeta:
num = num.item()
shapes = []
dtypes = []
stop_grads = []
for i in range(num):
shape, dtype = self._recv_shape_dtype(group)
shape, dtype, stop_grad = self._recv_shape_dtype(group)
shapes.append(shape)
dtypes.append(dtype)
stop_grads.append(bool(stop_grad))
self.recv_shape_message = tuple(shapes)
self.recv_dtype_message = tuple(dtypes)
self.recv_stop_gradient = tuple(stop_grads)
def _send_dims_shape_dtype(self, tensor, group):
# send len(shape)
......@@ -96,6 +107,10 @@ class SendRecvMeta:
dtype = paddle.to_tensor(paddle_2_number(tensor.dtype))
paddle.distributed.send(dtype, dst=1, group=group)
# send trainable
stop_grad = paddle.to_tensor(int(tensor.stop_gradient))
paddle.distributed.send(stop_grad, dst=1, group=group)
def send_meta(self, tensor, group):
if isinstance(tensor, paddle.Tensor):
tensor_type = paddle.to_tensor([0])
......@@ -129,6 +144,12 @@ class SendRecvMeta:
_send_recv_meta = SendRecvMeta()
def _is_valid_send_recv_partial(tensor, mp_degree):
tensor_numel = np.prod(tensor.shape)
assert tensor_numel != 0, "can't send/recv zero element"
return mp_degree > 1 and tensor_numel % mp_degree == 0
def send_partial(tensor,
dst=0,
nranks=1,
......@@ -138,9 +159,14 @@ def send_partial(tensor,
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
return paddle.fluid.core.ops.partial_send(
tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'peer',
dst, 'num', nranks, 'id', rank_id)
if _is_valid_send_recv_partial(tensor, nranks):
return _C_ops.partial_send(tensor, 'use_calc_stream', use_calc_stream,
'ring_id', ring_id, 'peer', dst, 'num',
nranks, 'id', rank_id)
else:
return paddle.distributed.send(
tensor, dst=dst, group=group, use_calc_stream=use_calc_stream)
def recv_partial(tensor,
......@@ -153,10 +179,14 @@ def recv_partial(tensor,
return
ring_id = 0 if group is None else group.id
paddle.fluid.core.ops.partial_recv(
tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'peer',
src, 'num', nranks, 'id', rank_id, 'dtype', tensor.dtype, 'out_shape',
if _is_valid_send_recv_partial(tensor, nranks):
_C_ops.partial_recv(tensor, 'use_calc_stream', use_calc_stream,
'ring_id', ring_id, 'peer', src, 'num', nranks,
'id', rank_id, 'dtype', tensor.dtype, 'out_shape',
tensor.shape)
else:
paddle.distributed.recv(
tensor, src=src, group=group, use_calc_stream=use_calc_stream)
def allgather_partial(tensor,
......@@ -164,15 +194,15 @@ def allgather_partial(tensor,
rank_id=0,
group=None,
use_calc_stream=True):
if nranks == 1:
if not _is_valid_send_recv_partial(tensor, nranks):
return tensor
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
return paddle.fluid.core.ops.partial_allgather_(
tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id,
'nranks', nranks, 'rank', rank_id)
return _C_ops.partial_allgather_(tensor, 'use_calc_stream', use_calc_stream,
'ring_id', ring_id, 'nranks', nranks,
'rank', rank_id)
def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
......@@ -184,6 +214,8 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
# send / recv message
recv_shape_msg = _send_recv_meta.recv_shape_message
recv_dtype_msg = _send_recv_meta.recv_dtype_message
recv_stop_gradient = _send_recv_meta.recv_stop_gradient
send_shape_msg = _send_recv_meta.send_shape_message
send_dtype_msg = _send_recv_meta.send_dtype_message
......@@ -196,13 +228,16 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
if isinstance(recv_shape_msg, tuple):
tensor_recv_prev = []
for idx, shape in enumerate(recv_shape_msg):
tensor_recv_prev.append(
paddle.empty(
shape=shape, dtype=number_2_dtype(recv_dtype_msg[idx])))
tmp = paddle.empty(
shape=shape, dtype=number_2_dtype(recv_dtype_msg[idx]))
tmp.stop_gradient = recv_stop_gradient[idx]
tensor_recv_prev.append(tmp)
tensor_recv_prev = tuple(tensor_recv_prev)
else:
tensor_recv_prev = paddle.empty(
shape=recv_shape_msg, dtype=number_2_dtype(recv_dtype_msg))
tensor_recv_prev.stop_gradient = recv_stop_gradient
if recv_next:
if isinstance(send_shape_msg, tuple):
......@@ -222,7 +257,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
for d in tensor_send_prev:
paddle.distributed.wait(d, use_calc_stream=True)
send_partial(
d,
d.detach(),
dst=0,
nranks=mp_degree,
rank_id=mp_rank,
......@@ -231,7 +266,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
else:
paddle.distributed.wait(tensor_send_prev, use_calc_stream=True)
send_partial(
tensor_send_prev,
tensor_send_prev.detach(),
dst=0,
nranks=mp_degree,
rank_id=mp_rank,
......@@ -242,28 +277,28 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
if isinstance(tensor_recv_prev, tuple):
for d in tensor_recv_prev:
recv_partial(
d,
d.detach(),
src=0,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.recv_prev_group,
use_calc_stream=True)
allgather_partial(
d,
d.detach(),
nranks=mp_degree,
rank_id=mp_rank,
group=mp_group,
use_calc_stream=True)
else:
recv_partial(
tensor_recv_prev,
tensor_recv_prev.detach(),
src=0,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.recv_prev_group,
use_calc_stream=True)
allgather_partial(
tensor_recv_prev,
tensor_recv_prev.detach(),
nranks=mp_degree,
rank_id=mp_rank,
group=mp_group,
......@@ -274,7 +309,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
for d in tensor_send_next:
paddle.distributed.wait(d, use_calc_stream=True)
send_partial(
d,
d.detach(),
dst=1,
nranks=mp_degree,
rank_id=mp_rank,
......@@ -283,7 +318,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
else:
paddle.distributed.wait(tensor_send_next, use_calc_stream=True)
send_partial(
tensor_send_next,
tensor_send_next.detach(),
dst=1,
nranks=mp_degree,
rank_id=mp_rank,
......@@ -294,14 +329,14 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
if isinstance(tensor_recv_next, tuple):
for d in tensor_recv_next:
recv_partial(
d,
d.detach(),
src=1,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.recv_next_group,
use_calc_stream=True)
allgather_partial(
d,
d.detach(),
nranks=mp_degree,
rank_id=mp_rank,
group=mp_group,
......@@ -309,7 +344,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
else:
recv_partial(
tensor_recv_next,
tensor_recv_next.detach(),
src=1,
nranks=mp_degree,
rank_id=mp_rank,
......@@ -317,7 +352,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
use_calc_stream=True)
allgather_partial(
tensor_recv_next,
tensor_recv_next.detach(),
nranks=mp_degree,
rank_id=mp_rank,
group=mp_group,
......
......@@ -54,13 +54,17 @@ class EmbeddingNet(Layer):
attention_mask = paddle.tensor.triu(
(paddle.ones(
(length, length), dtype="float32") * -1e9), 1)
attention_mask.stop_gradient = True
no_used = paddle.ones((3, 3), dtype="int32")
w_emb = self.word_embeddings(x)
p_emb = self.position_embeddings(x)
w_emb = w_emb + p_emb
attention_mask.stop_gradient = True
no_used.stop_gradient = True
# need to fix bug of backward()
return w_emb, attention_mask
return w_emb, attention_mask, no_used, p_emb
class TransformerNet(Layer):
......@@ -99,12 +103,12 @@ class EmbeddingPipe(EmbeddingNet):
class TransformerNetPipe(TransformerNet):
def forward(self, args):
x, mask = args[0], args[1]
x, mask, no_used, p_emb = args[0], args[1], args[2], args[3]
output = super().forward(x, mask)
output = output
output = output + p_emb
mask.stop_gradient = True
return output, mask
return output, mask, no_used, p_emb
class CriterionPipe(Layer):
......@@ -175,6 +179,8 @@ class TestDistPPTraning(unittest.TestCase):
loss = model.train_batch([x, x], optimizer, scheduler)
# TODO(shenliang03) add utest for loss
print("loss: ", loss)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册