From 77c010a01dfe8892f044517b3f94341c2c9ab086 Mon Sep 17 00:00:00 2001
From: ShenLiang <1422485404@qq.com>
Date: Wed, 13 Jul 2022 16:04:18 +0800
Subject: [PATCH] fix bug of pp (#44276)

---
 .../pp_utils/p2p_communication.py | 40 ++++++++++++-------
 1 file changed, 25 insertions(+), 15 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
index 17c7f5a9bb..6f917d9f89 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
@@ -54,25 +54,29 @@ class SendRecvMeta:
     def _recv_shape_dtype(self, group):
         # recv len(shape)
         dims = paddle.to_tensor([0])
-        paddle.distributed.recv(dims, src=0, group=group)
+        src_rank = group.ranks[0]
+
+        paddle.distributed.recv(dims, src=src_rank, group=group)
         dims = dims.item()
 
         # recv shape
         shape = paddle.to_tensor([0] * dims)
-        paddle.distributed.recv(shape, src=0, group=group)
+        paddle.distributed.recv(shape, src=src_rank, group=group)
 
         # recv dtype
         dtype = paddle.to_tensor([0])
-        paddle.distributed.recv(dtype, src=0, group=group)
+        paddle.distributed.recv(dtype, src=src_rank, group=group)
 
         # recv stop_gradient
         stop_grad = paddle.to_tensor([0])
-        paddle.distributed.recv(stop_grad, src=0, group=group)
+        paddle.distributed.recv(stop_grad, src=src_rank, group=group)
         return shape.numpy().tolist(), dtype.item(), stop_grad.item()
 
     def recv_meta(self, group):
         tensor_type = paddle.to_tensor([0])
-        paddle.distributed.recv(tensor_type, src=0, group=group)
+        src_rank = group.ranks[0]
+
+        paddle.distributed.recv(tensor_type, src=src_rank, group=group)
         tensor_type = tensor_type.item()
 
         if tensor_type == 0:
@@ -83,7 +87,7 @@ class SendRecvMeta:
 
         elif tensor_type == 1:
             num = paddle.to_tensor([0])
-            paddle.distributed.recv(num, src=0, group=group)
+            paddle.distributed.recv(num, src=src_rank, group=group)
             num = num.item()
             shapes = []
             dtypes = []
@@ -101,34 +105,38 @@ class SendRecvMeta:
     def _send_dims_shape_dtype(self, tensor, group):
         # send len(shape)
         dims = paddle.to_tensor(len(tensor.shape))
-        paddle.distributed.send(dims, dst=1, group=group)
+        dst_rank = group.ranks[1]
+
+        paddle.distributed.send(dims, dst=dst_rank, group=group)
 
         # send shape
         shape = paddle.to_tensor(tensor.shape)
-        paddle.distributed.send(shape, dst=1, group=group)
+        paddle.distributed.send(shape, dst=dst_rank, group=group)
 
         # send dtype
         dtype = paddle.to_tensor(paddle_2_number(tensor.dtype))
-        paddle.distributed.send(dtype, dst=1, group=group)
+        paddle.distributed.send(dtype, dst=dst_rank, group=group)
 
         # send trainable
         stop_grad = paddle.to_tensor(int(tensor.stop_gradient))
-        paddle.distributed.send(stop_grad, dst=1, group=group)
+        paddle.distributed.send(stop_grad, dst=dst_rank, group=group)
 
     def send_meta(self, tensor, group):
+        dst_rank = group.ranks[1]
+
         if isinstance(tensor, (paddle.Tensor, core.eager.Tensor)):
             tensor_type = paddle.to_tensor([0])
             # send tensor type
-            paddle.distributed.send(tensor_type, dst=1, group=group)
+            paddle.distributed.send(tensor_type, dst=dst_rank, group=group)
 
             self._send_dims_shape_dtype(tensor, group)
         elif isinstance(tensor, tuple):
             tensor_type = paddle.to_tensor([1])
             # send tensor type
-            paddle.distributed.send(tensor_type, dst=1, group=group)
+            paddle.distributed.send(tensor_type, dst=dst_rank, group=group)
 
             nums = paddle.to_tensor(len(tensor))
-            paddle.distributed.send(nums, dst=1, group=group)
+            paddle.distributed.send(nums, dst=dst_rank, group=group)
 
             for d in tensor:
                 assert isinstance(d, (paddle.Tensor, core.eager.Tensor))
@@ -166,6 +174,7 @@ def send_partial(tensor,
                  rank_id=0,
                  group=None,
                  use_calc_stream=True):
+    # dst: local rank in group
     if group is not None and not group.is_member():
         return
     ring_id = 0 if group is None else group.id
@@ -176,7 +185,7 @@ def send_partial(tensor,
                                    dst, 'num', nranks, 'id', rank_id)
     else:
         return paddle.distributed.send(tensor.detach(),
-                                       dst=dst,
+                                       dst=group.ranks[dst],
                                        group=group,
                                        use_calc_stream=use_calc_stream)
 
@@ -187,6 +196,7 @@ def recv_partial(tensor,
                  rank_id=0,
                  group=None,
                  use_calc_stream=True):
+    # src: local rank in group
     if group is not None and not group.is_member():
         return
     ring_id = 0 if group is None else group.id
@@ -198,7 +208,7 @@ def recv_partial(tensor,
                                    tensor.shape)
     else:
         paddle.distributed.recv(tensor.detach(),
-                                src=src,
+                                src=group.ranks[src],
                                 group=group,
                                 use_calc_stream=use_calc_stream)
 
-- 
GitLab
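
Note (not part of the patch): a minimal sketch of the rank translation this fix applies. The meta exchange previously hard-coded src=0 / dst=1, which paddle.distributed.send/recv interpret as global ranks, while the intended peers are the first and second members of the point-to-point group; group.ranks maps a group-local position to that global rank, as the added comments in send_partial/recv_partial spell out. P2PGroup below is a hypothetical stand-in used only for illustration; its only relevant attribute is the ranks list carried by a real paddle.distributed communication group.

# Minimal sketch, assuming the group lists its members' global ranks
# in group-local order (as the patch relies on via group.ranks).
class P2PGroup:
    def __init__(self, ranks):
        # Global ranks of the group's members, indexed by local position.
        self.ranks = ranks

# A pipeline-parallel p2p group formed by global ranks 2 and 3.
group = P2PGroup(ranks=[2, 3])

# Before the fix: src=0 / dst=1 addressed global ranks 0 and 1,
# which do not belong to this group.
# After the fix: group-local positions 0 (sender) and 1 (receiver)
# are translated to the global ranks the send/recv calls expect.
src_rank = group.ranks[0]   # global rank 2
dst_rank = group.ranks[1]   # global rank 3

assert (src_rank, dst_rank) == (2, 3)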