Unverified commit 4cc3d9a2, authored by ShenLiang, committed by GitHub

[HybridParallel]Fix bug of p2p for partial_send/recv (#34615)

* fix bug of p2p for partial

* fix error
Parent 090c863a
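Note on the change: partial_send / partial_recv / partial_allgather split a tensor into mp_degree equal slices, so they are only valid when the tensor's element count divides evenly by the model-parallel degree (and mp_degree > 1). This patch adds that check and falls back to plain paddle.distributed.send / recv otherwise. It also sends the tensor's stop_gradient flag along with the existing shape/dtype meta, so the receiving pipeline stage allocates its buffer with the sender's trainability and the _set_tensor_trainable workaround can be removed. A minimal sketch of the divisibility check follows; names mirror the diff below, the signature takes a shape tuple instead of a Tensor so it runs standalone, and the example shapes and mp_degree values are only illustrative:

import numpy as np

def _is_valid_send_recv_partial(tensor_shape, mp_degree):
    # partial_send/recv move 1/mp_degree of the buffer per model-parallel
    # rank, so the element count must divide evenly across mp_degree.
    numel = int(np.prod(tensor_shape))
    assert numel != 0, "can't send/recv zero element"
    return mp_degree > 1 and numel % mp_degree == 0

# Illustrative only: a (3, 3) flag tensor cannot be split across 2 ranks,
# so the new code path falls back to paddle.distributed.send/recv for it.
print(_is_valid_send_recv_partial((3, 3), 2))      # False -> plain send/recv
print(_is_valid_send_recv_partial((16, 1024), 2))  # True  -> partial ops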
@@ -64,18 +64,6 @@ class PipelineParallel(MetaParallelBase):
             logger.info("start broadcast dp parameters")
             broadcast_dp_parameters(self._layers, self._hcg)
 
-    def _set_tensor_trainable(self, tensor):
-        if tensor is None:
-            return
-
-        if isinstance(tensor, tuple):
-            for t in tensor:
-                if is_float_tensor(t):
-                    t.stop_gradient = False
-        else:
-            if is_float_tensor(tensor):
-                tensor.stop_gradient = False
-
     def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
         assert isinstance(optimizer, HybridParallelOptimizer), (
             'optimizer should be HybridParallelOptimizer subclass.')
@@ -117,7 +105,6 @@ class PipelineParallel(MetaParallelBase):
 
         for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward()
-            self._set_tensor_trainable(input_tensor)
 
            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor)
@@ -131,7 +118,6 @@ class PipelineParallel(MetaParallelBase):
        for i in range(steady_steps):
            last_iter = (i == (steady_steps - 1))
 
-            self._set_tensor_trainable(input_tensor)
            output_tensor = self._forward_step(input_tensor)
 
            output_tensor_grad = p2p.send_forward_recv_backward(output_tensor)
...
@@ -15,6 +15,8 @@
 import paddle
 from .utils import paddle_2_number, number_2_dtype
 from ...utils.log_util import logger
+import numpy as np
+from paddle import _C_ops
 
 _hcg = None
@@ -40,6 +42,7 @@ class SendRecvMeta:
         self.recv_shape_message = None
         self.recv_dtype_message = None
+        self.recv_stop_gradient = None
 
         self.has_send_meta = False
         self.has_recv_meta = False
@@ -57,7 +60,11 @@ class SendRecvMeta:
         # recv dtype
         dtype = paddle.to_tensor([0])
         paddle.distributed.recv(dtype, src=0, group=group)
-        return shape.numpy().tolist(), dtype.item()
+
+        # recv stop_gradient
+        stop_grad = paddle.to_tensor([0])
+        paddle.distributed.recv(stop_grad, src=0, group=group)
+        return shape.numpy().tolist(), dtype.item(), stop_grad.item()
 
     def recv_meta(self, group):
         tensor_type = paddle.to_tensor([0])
@@ -65,9 +72,10 @@ class SendRecvMeta:
         tensor_type = tensor_type.item()
 
         if tensor_type == 0:
-            shape, dtype = self._recv_shape_dtype(group)
+            shape, dtype, stop_grad = self._recv_shape_dtype(group)
             self.recv_shape_message = shape
             self.recv_dtype_message = dtype
+            self.recv_stop_gradient = bool(stop_grad)
 
         elif tensor_type == 1:
             num = paddle.to_tensor([0])
@@ -75,13 +83,16 @@ class SendRecvMeta:
             num = num.item()
             shapes = []
             dtypes = []
+            stop_grads = []
             for i in range(num):
-                shape, dtype = self._recv_shape_dtype(group)
+                shape, dtype, stop_grad = self._recv_shape_dtype(group)
                 shapes.append(shape)
                 dtypes.append(dtype)
+                stop_grads.append(bool(stop_grad))
 
             self.recv_shape_message = tuple(shapes)
             self.recv_dtype_message = tuple(dtypes)
+            self.recv_stop_gradient = tuple(stop_grads)
 
     def _send_dims_shape_dtype(self, tensor, group):
         # send len(shape)
...@@ -96,6 +107,10 @@ class SendRecvMeta: ...@@ -96,6 +107,10 @@ class SendRecvMeta:
dtype = paddle.to_tensor(paddle_2_number(tensor.dtype)) dtype = paddle.to_tensor(paddle_2_number(tensor.dtype))
paddle.distributed.send(dtype, dst=1, group=group) paddle.distributed.send(dtype, dst=1, group=group)
# send trainable
stop_grad = paddle.to_tensor(int(tensor.stop_gradient))
paddle.distributed.send(stop_grad, dst=1, group=group)
def send_meta(self, tensor, group): def send_meta(self, tensor, group):
if isinstance(tensor, paddle.Tensor): if isinstance(tensor, paddle.Tensor):
tensor_type = paddle.to_tensor([0]) tensor_type = paddle.to_tensor([0])
@@ -129,6 +144,12 @@ class SendRecvMeta:
 _send_recv_meta = SendRecvMeta()
 
 
+def _is_valid_send_recv_partial(tensor, mp_degree):
+    tensor_numel = np.prod(tensor.shape)
+    assert tensor_numel != 0, "can't send/recv zero element"
+    return mp_degree > 1 and tensor_numel % mp_degree == 0
+
+
 def send_partial(tensor,
                  dst=0,
                  nranks=1,
...@@ -138,9 +159,14 @@ def send_partial(tensor, ...@@ -138,9 +159,14 @@ def send_partial(tensor,
if group is not None and not group.is_member(): if group is not None and not group.is_member():
return return
ring_id = 0 if group is None else group.id ring_id = 0 if group is None else group.id
return paddle.fluid.core.ops.partial_send(
tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'peer', if _is_valid_send_recv_partial(tensor, nranks):
dst, 'num', nranks, 'id', rank_id) return _C_ops.partial_send(tensor, 'use_calc_stream', use_calc_stream,
'ring_id', ring_id, 'peer', dst, 'num',
nranks, 'id', rank_id)
else:
return paddle.distributed.send(
tensor, dst=dst, group=group, use_calc_stream=use_calc_stream)
def recv_partial(tensor, def recv_partial(tensor,
...@@ -153,10 +179,14 @@ def recv_partial(tensor, ...@@ -153,10 +179,14 @@ def recv_partial(tensor,
return return
ring_id = 0 if group is None else group.id ring_id = 0 if group is None else group.id
paddle.fluid.core.ops.partial_recv( if _is_valid_send_recv_partial(tensor, nranks):
tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'peer', _C_ops.partial_recv(tensor, 'use_calc_stream', use_calc_stream,
src, 'num', nranks, 'id', rank_id, 'dtype', tensor.dtype, 'out_shape', 'ring_id', ring_id, 'peer', src, 'num', nranks,
tensor.shape) 'id', rank_id, 'dtype', tensor.dtype, 'out_shape',
tensor.shape)
else:
paddle.distributed.recv(
tensor, src=src, group=group, use_calc_stream=use_calc_stream)
def allgather_partial(tensor, def allgather_partial(tensor,
@@ -164,15 +194,15 @@ def allgather_partial(tensor,
                       rank_id=0,
                       group=None,
                       use_calc_stream=True):
-    if nranks == 1:
+    if not _is_valid_send_recv_partial(tensor, nranks):
         return tensor
     if group is not None and not group.is_member():
         return
     ring_id = 0 if group is None else group.id
-    return paddle.fluid.core.ops.partial_allgather_(
-        tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id,
-        'nranks', nranks, 'rank', rank_id)
+    return _C_ops.partial_allgather_(tensor, 'use_calc_stream', use_calc_stream,
+                                     'ring_id', ring_id, 'nranks', nranks,
+                                     'rank', rank_id)
 
 
 def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
@@ -184,6 +214,8 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
     # send / recv message
     recv_shape_msg = _send_recv_meta.recv_shape_message
     recv_dtype_msg = _send_recv_meta.recv_dtype_message
+    recv_stop_gradient = _send_recv_meta.recv_stop_gradient
+
     send_shape_msg = _send_recv_meta.send_shape_message
     send_dtype_msg = _send_recv_meta.send_dtype_message
@@ -196,13 +228,16 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
         if isinstance(recv_shape_msg, tuple):
             tensor_recv_prev = []
             for idx, shape in enumerate(recv_shape_msg):
-                tensor_recv_prev.append(
-                    paddle.empty(
-                        shape=shape, dtype=number_2_dtype(recv_dtype_msg[idx])))
+                tmp = paddle.empty(
+                    shape=shape, dtype=number_2_dtype(recv_dtype_msg[idx]))
+                tmp.stop_gradient = recv_stop_gradient[idx]
+                tensor_recv_prev.append(tmp)
             tensor_recv_prev = tuple(tensor_recv_prev)
         else:
             tensor_recv_prev = paddle.empty(
                 shape=recv_shape_msg, dtype=number_2_dtype(recv_dtype_msg))
+            tensor_recv_prev.stop_gradient = recv_stop_gradient
 
     if recv_next:
         if isinstance(send_shape_msg, tuple):
@@ -222,7 +257,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
             for d in tensor_send_prev:
                 paddle.distributed.wait(d, use_calc_stream=True)
                 send_partial(
-                    d,
+                    d.detach(),
                     dst=0,
                     nranks=mp_degree,
                     rank_id=mp_rank,
@@ -231,7 +266,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
         else:
             paddle.distributed.wait(tensor_send_prev, use_calc_stream=True)
             send_partial(
-                tensor_send_prev,
+                tensor_send_prev.detach(),
                 dst=0,
                 nranks=mp_degree,
                 rank_id=mp_rank,
@@ -242,28 +277,28 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
         if isinstance(tensor_recv_prev, tuple):
             for d in tensor_recv_prev:
                 recv_partial(
-                    d,
+                    d.detach(),
                     src=0,
                     nranks=mp_degree,
                     rank_id=mp_rank,
                     group=_hcg.recv_prev_group,
                     use_calc_stream=True)
                 allgather_partial(
-                    d,
+                    d.detach(),
                     nranks=mp_degree,
                     rank_id=mp_rank,
                     group=mp_group,
                     use_calc_stream=True)
         else:
             recv_partial(
-                tensor_recv_prev,
+                tensor_recv_prev.detach(),
                 src=0,
                 nranks=mp_degree,
                 rank_id=mp_rank,
                 group=_hcg.recv_prev_group,
                 use_calc_stream=True)
             allgather_partial(
-                tensor_recv_prev,
+                tensor_recv_prev.detach(),
                 nranks=mp_degree,
                 rank_id=mp_rank,
                 group=mp_group,
@@ -274,7 +309,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
             for d in tensor_send_next:
                 paddle.distributed.wait(d, use_calc_stream=True)
                 send_partial(
-                    d,
+                    d.detach(),
                     dst=1,
                     nranks=mp_degree,
                     rank_id=mp_rank,
@@ -283,7 +318,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
         else:
             paddle.distributed.wait(tensor_send_next, use_calc_stream=True)
             send_partial(
-                tensor_send_next,
+                tensor_send_next.detach(),
                 dst=1,
                 nranks=mp_degree,
                 rank_id=mp_rank,
@@ -294,14 +329,14 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
         if isinstance(tensor_recv_next, tuple):
             for d in tensor_recv_next:
                 recv_partial(
-                    d,
+                    d.detach(),
                     src=1,
                     nranks=mp_degree,
                     rank_id=mp_rank,
                     group=_hcg.recv_next_group,
                     use_calc_stream=True)
                 allgather_partial(
-                    d,
+                    d.detach(),
                     nranks=mp_degree,
                     rank_id=mp_rank,
                     group=mp_group,
@@ -309,7 +344,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
         else:
             recv_partial(
-                tensor_recv_next,
+                tensor_recv_next.detach(),
                 src=1,
                 nranks=mp_degree,
                 rank_id=mp_rank,
@@ -317,7 +352,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
                 use_calc_stream=True)
             allgather_partial(
-                tensor_recv_next,
+                tensor_recv_next.detach(),
                 nranks=mp_degree,
                 rank_id=mp_rank,
                 group=mp_group,
...
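For reference, a simplified sketch of how the new stop_gradient flag rides along with the existing shape/dtype handshake between adjacent pipeline stages; rank 0 sends and rank 1 receives within the p2p group, as in the diff above, and the group setup itself is assumed:

import paddle
import paddle.distributed as dist

def send_stop_gradient(tensor, group):
    # New piece of the meta exchange: ship the trainable flag as a tensor.
    stop_grad = paddle.to_tensor(int(tensor.stop_gradient))
    dist.send(stop_grad, dst=1, group=group)

def recv_buffer(shape, dtype, group):
    stop_grad = paddle.to_tensor([0])
    dist.recv(stop_grad, src=0, group=group)
    buf = paddle.empty(shape=shape, dtype=dtype)
    # The receive buffer mirrors the sender's trainability, so the backward
    # pass only produces gradients for tensors that actually need them.
    buf.stop_gradient = bool(stop_grad.item())
    return buf

This is what lets pipeline_parallel.py drop _set_tensor_trainable, which previously marked every received float tensor as trainable regardless of the sender's setting.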
@@ -54,13 +54,17 @@ class EmbeddingNet(Layer):
         attention_mask = paddle.tensor.triu(
             (paddle.ones(
                 (length, length), dtype="float32") * -1e9), 1)
-        attention_mask.stop_gradient = True
+
+        no_used = paddle.ones((3, 3), dtype="int32")
 
         w_emb = self.word_embeddings(x)
         p_emb = self.position_embeddings(x)
         w_emb = w_emb + p_emb
+
+        attention_mask.stop_gradient = True
+        no_used.stop_gradient = True
         # need to fix bug of backward()
-        return w_emb, attention_mask
+        return w_emb, attention_mask, no_used, p_emb
 
 
 class TransformerNet(Layer):
@@ -99,12 +103,12 @@ class EmbeddingPipe(EmbeddingNet):
 
 class TransformerNetPipe(TransformerNet):
     def forward(self, args):
-        x, mask = args[0], args[1]
+        x, mask, no_used, p_emb = args[0], args[1], args[2], args[3]
 
         output = super().forward(x, mask)
-        output = output
+        output = output + p_emb
 
         mask.stop_gradient = True
-        return output, mask
+        return output, mask, no_used, p_emb
 
 
 class CriterionPipe(Layer):
@@ -175,6 +179,8 @@ class TestDistPPTraning(unittest.TestCase):
             loss = model.train_batch([x, x], optimizer, scheduler)
 
             # TODO(shenliang03) add utest for loss
+            print("loss: ", loss)
+
 
 if __name__ == "__main__":
     unittest.main()