Unverified · Commit 77c010a0 authored by ShenLiang, committed by GitHub

fix bug of pp (#44276)

Parent: b1aa693e
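The patch is small but easy to misread in the flattened diff below: the pipeline-parallel p2p meta exchange previously hard-coded `src=0` and `dst=1`, which are global ranks and are only correct when the two communicating stages happen to sit on ranks 0 and 1 of the whole job. The fix resolves the peer through `group.ranks`, i.e. it maps a local index inside the p2p group to the corresponding global rank. A minimal sketch of that mapping (the `FakeGroup` helper and `resolve_peer` are hypothetical, for illustration only; they are not part of the patch):

```python
# Hypothetical stand-in for a two-member p2p process group; only the
# `ranks` attribute (global ranks of the members) matters here.
class FakeGroup:
    def __init__(self, ranks):
        self.ranks = ranks


def resolve_peer(group, local_index):
    # The idea behind the fix: treat the peer as a *local* index inside the
    # group and map it to the global rank that send/recv actually expect.
    return group.ranks[local_index]


# A p2p group between pipeline stages living on global ranks 2 and 3.
group = FakeGroup(ranks=[2, 3])
assert resolve_peer(group, 0) == 2  # metadata is received from the first member
assert resolve_peer(group, 1) == 3  # ... and sent to the second member

# The old code passed src=0 / dst=1 directly, which is only correct when the
# group happens to be exactly [0, 1]; for any later pair of stages the
# messages would be addressed to the wrong ranks.
```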
```diff
@@ -54,25 +54,29 @@ class SendRecvMeta:
     def _recv_shape_dtype(self, group):
         # recv len(shape)
         dims = paddle.to_tensor([0])
-        paddle.distributed.recv(dims, src=0, group=group)
+        src_rank = group.ranks[0]
+
+        paddle.distributed.recv(dims, src=src_rank, group=group)
         dims = dims.item()

         # recv shape
         shape = paddle.to_tensor([0] * dims)
-        paddle.distributed.recv(shape, src=0, group=group)
+        paddle.distributed.recv(shape, src=src_rank, group=group)

         # recv dtype
         dtype = paddle.to_tensor([0])
-        paddle.distributed.recv(dtype, src=0, group=group)
+        paddle.distributed.recv(dtype, src=src_rank, group=group)

         # recv stop_gradient
         stop_grad = paddle.to_tensor([0])
-        paddle.distributed.recv(stop_grad, src=0, group=group)
+        paddle.distributed.recv(stop_grad, src=src_rank, group=group)
         return shape.numpy().tolist(), dtype.item(), stop_grad.item()

     def recv_meta(self, group):
         tensor_type = paddle.to_tensor([0])
-        paddle.distributed.recv(tensor_type, src=0, group=group)
+        src_rank = group.ranks[0]
+
+        paddle.distributed.recv(tensor_type, src=src_rank, group=group)
         tensor_type = tensor_type.item()

         if tensor_type == 0:
@@ -83,7 +87,7 @@ class SendRecvMeta:
         elif tensor_type == 1:
             num = paddle.to_tensor([0])
-            paddle.distributed.recv(num, src=0, group=group)
+            paddle.distributed.recv(num, src=src_rank, group=group)
             num = num.item()
             shapes = []
             dtypes = []
```
```diff
@@ -101,34 +105,38 @@ class SendRecvMeta:
     def _send_dims_shape_dtype(self, tensor, group):
         # send len(shape)
         dims = paddle.to_tensor(len(tensor.shape))
-        paddle.distributed.send(dims, dst=1, group=group)
+        dst_rank = group.ranks[1]
+
+        paddle.distributed.send(dims, dst=dst_rank, group=group)

         # send shape
         shape = paddle.to_tensor(tensor.shape)
-        paddle.distributed.send(shape, dst=1, group=group)
+        paddle.distributed.send(shape, dst=dst_rank, group=group)

         # send dtype
         dtype = paddle.to_tensor(paddle_2_number(tensor.dtype))
-        paddle.distributed.send(dtype, dst=1, group=group)
+        paddle.distributed.send(dtype, dst=dst_rank, group=group)

         # send trainable
         stop_grad = paddle.to_tensor(int(tensor.stop_gradient))
-        paddle.distributed.send(stop_grad, dst=1, group=group)
+        paddle.distributed.send(stop_grad, dst=dst_rank, group=group)

     def send_meta(self, tensor, group):
+        dst_rank = group.ranks[1]
+
         if isinstance(tensor, (paddle.Tensor, core.eager.Tensor)):
             tensor_type = paddle.to_tensor([0])
             # send tensor type
-            paddle.distributed.send(tensor_type, dst=1, group=group)
+            paddle.distributed.send(tensor_type, dst=dst_rank, group=group)

             self._send_dims_shape_dtype(tensor, group)
         elif isinstance(tensor, tuple):
             tensor_type = paddle.to_tensor([1])
             # send tensor type
-            paddle.distributed.send(tensor_type, dst=1, group=group)
+            paddle.distributed.send(tensor_type, dst=dst_rank, group=group)

             nums = paddle.to_tensor(len(tensor))
-            paddle.distributed.send(nums, dst=1, group=group)
+            paddle.distributed.send(nums, dst=dst_rank, group=group)

             for d in tensor:
                 assert isinstance(d, (paddle.Tensor, core.eager.Tensor))
```
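For context on what these two methods exchange, `send_meta` on the upstream stage and `recv_meta` on the downstream stage must agree on the order of the small control messages: a tensor-type flag (0 for a single tensor, 1 for a tuple), then for each tensor its number of dimensions, shape, dtype code, and stop_gradient flag. A rough, framework-free sketch of that ordering as read from the hunks above; the per-element body of the tuple branch is truncated in this view, so the loop below is an assumption:

```python
# Sketch only: the sequence of scalar messages send_meta emits, as read from
# the diff above.  Values in quotes are placeholders for what the real code
# sends via paddle.distributed.send.

def meta_message_order(shapes, is_tuple):
    """List the control messages send_meta would emit, in order."""
    messages = []
    if not is_tuple:
        messages.append(("tensor_type", 0))
        per_tensor = [shapes[0]]
    else:
        messages.append(("tensor_type", 1))
        messages.append(("num", len(shapes)))
        per_tensor = shapes          # assumed: each element repeats the block below
    for shape in per_tensor:
        messages += [
            ("dims", len(shape)),                      # len(tensor.shape)
            ("shape", list(shape)),                    # the shape itself
            ("dtype", "paddle_2_number(tensor.dtype)"),
            ("stop_gradient", "0 or 1"),
        ]
    return messages


# Example: a tuple of two activations shaped [4, 8] and [4].
for name, value in meta_message_order([[4, 8], [4]], is_tuple=True):
    print(name, value)
```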
```diff
@@ -166,6 +174,7 @@ def send_partial(tensor,
                  rank_id=0,
                  group=None,
                  use_calc_stream=True):
+    # dst: local rank in group
     if group is not None and not group.is_member():
         return
     ring_id = 0 if group is None else group.id
@@ -176,7 +185,7 @@ def send_partial(tensor,
                              dst, 'num', nranks, 'id', rank_id)
     else:
         return paddle.distributed.send(tensor.detach(),
-                                       dst=dst,
+                                       dst=group.ranks[dst],
                                        group=group,
                                        use_calc_stream=use_calc_stream)
@@ -187,6 +196,7 @@ def recv_partial(tensor,
                  rank_id=0,
                  group=None,
                  use_calc_stream=True):
+    # src: local rank in group
     if group is not None and not group.is_member():
         return
     ring_id = 0 if group is None else group.id
@@ -198,7 +208,7 @@ def recv_partial(tensor,
                                 tensor.shape)
     else:
         paddle.distributed.recv(tensor.detach(),
-                                src=src,
+                                src=group.ranks[src],
                                 group=group,
                                 use_calc_stream=use_calc_stream)
```
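`send_partial` and `recv_partial` get the same treatment, and the new `# dst: local rank in group` / `# src: local rank in group` comments spell out the contract: the non-partial fallback now indexes `group.ranks` with the given value, so callers are expected to pass the peer's position inside the group rather than its global rank. A tiny sketch of that contract, reusing the hypothetical `FakeGroup` stand-in from above:

```python
# Why the "local rank in group" comments matter: after this change the
# fallback branch indexes group.ranks with dst/src, so callers must pass the
# peer's position inside the group, not its global rank.

class FakeGroup:
    def __init__(self, ranks):
        self.ranks = ranks


group = FakeGroup(ranks=[6, 7])   # p2p group between global ranks 6 and 7

dst_local = 1                     # correct: second member of the group
print(group.ranks[dst_local])     # -> 7, the global rank send() will target

dst_global = 7                    # wrong after this patch: a global rank
try:
    group.ranks[dst_global]
except IndexError:
    print("passing a global rank now fails instead of silently mis-addressing")
```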