Unverified · Commit 77c010a0 authored by ShenLiang, committed by GitHub

fix bug of pp (#44276)

Parent: b1aa693e
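Description: this commit fixes point-to-point communication in pipeline parallel (pp). The meta send/recv helpers previously hard-coded src=0 and dst=1, which paddle.distributed.send/recv interpret as global ranks; that is only correct when the two-rank p2p group happens to contain global ranks 0 and 1. The fix resolves the group-local positions through group.ranks (group.ranks[0] is the sending stage, group.ranks[1] the receiving stage) and documents that the dst/src arguments of send_partial/recv_partial are local ranks within the group.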
@@ -54,25 +54,29 @@ class SendRecvMeta:
     def _recv_shape_dtype(self, group):
         # recv len(shape)
         dims = paddle.to_tensor([0])
-        paddle.distributed.recv(dims, src=0, group=group)
+        src_rank = group.ranks[0]
+
+        paddle.distributed.recv(dims, src=src_rank, group=group)
         dims = dims.item()

         # recv shape
         shape = paddle.to_tensor([0] * dims)
-        paddle.distributed.recv(shape, src=0, group=group)
+        paddle.distributed.recv(shape, src=src_rank, group=group)

         # recv dtype
         dtype = paddle.to_tensor([0])
-        paddle.distributed.recv(dtype, src=0, group=group)
+        paddle.distributed.recv(dtype, src=src_rank, group=group)

         # recv stop_gradient
         stop_grad = paddle.to_tensor([0])
-        paddle.distributed.recv(stop_grad, src=0, group=group)
+        paddle.distributed.recv(stop_grad, src=src_rank, group=group)
         return shape.numpy().tolist(), dtype.item(), stop_grad.item()

     def recv_meta(self, group):
         tensor_type = paddle.to_tensor([0])
-        paddle.distributed.recv(tensor_type, src=0, group=group)
+        src_rank = group.ranks[0]
+
+        paddle.distributed.recv(tensor_type, src=src_rank, group=group)
         tensor_type = tensor_type.item()

         if tensor_type == 0:
@@ -83,7 +87,7 @@ class SendRecvMeta:

         elif tensor_type == 1:
             num = paddle.to_tensor([0])
-            paddle.distributed.recv(num, src=0, group=group)
+            paddle.distributed.recv(num, src=src_rank, group=group)
             num = num.item()
             shapes = []
             dtypes = []
@@ -101,34 +105,38 @@ class SendRecvMeta:
     def _send_dims_shape_dtype(self, tensor, group):
         # send len(shape)
         dims = paddle.to_tensor(len(tensor.shape))
-        paddle.distributed.send(dims, dst=1, group=group)
+        dst_rank = group.ranks[1]
+
+        paddle.distributed.send(dims, dst=dst_rank, group=group)

         # send shape
         shape = paddle.to_tensor(tensor.shape)
-        paddle.distributed.send(shape, dst=1, group=group)
+        paddle.distributed.send(shape, dst=dst_rank, group=group)

         # send dtype
         dtype = paddle.to_tensor(paddle_2_number(tensor.dtype))
-        paddle.distributed.send(dtype, dst=1, group=group)
+        paddle.distributed.send(dtype, dst=dst_rank, group=group)

         # send trainable
         stop_grad = paddle.to_tensor(int(tensor.stop_gradient))
-        paddle.distributed.send(stop_grad, dst=1, group=group)
+        paddle.distributed.send(stop_grad, dst=dst_rank, group=group)

     def send_meta(self, tensor, group):
+        dst_rank = group.ranks[1]
+
         if isinstance(tensor, (paddle.Tensor, core.eager.Tensor)):
             tensor_type = paddle.to_tensor([0])
             # send tensor type
-            paddle.distributed.send(tensor_type, dst=1, group=group)
+            paddle.distributed.send(tensor_type, dst=dst_rank, group=group)
             self._send_dims_shape_dtype(tensor, group)
         elif isinstance(tensor, tuple):
             tensor_type = paddle.to_tensor([1])
             # send tensor type
-            paddle.distributed.send(tensor_type, dst=1, group=group)
+            paddle.distributed.send(tensor_type, dst=dst_rank, group=group)
             nums = paddle.to_tensor(len(tensor))
-            paddle.distributed.send(nums, dst=1, group=group)
+            paddle.distributed.send(nums, dst=dst_rank, group=group)

             for d in tensor:
                 assert isinstance(d, (paddle.Tensor, core.eager.Tensor))
@@ -166,6 +174,7 @@ def send_partial(tensor,
                  rank_id=0,
                  group=None,
                  use_calc_stream=True):
+    # dst: local rank in group
     if group is not None and not group.is_member():
         return
     ring_id = 0 if group is None else group.id
@@ -176,7 +185,7 @@
                                    dst, 'num', nranks, 'id', rank_id)
     else:
         return paddle.distributed.send(tensor.detach(),
-                                       dst=dst,
+                                       dst=group.ranks[dst],
                                        group=group,
                                        use_calc_stream=use_calc_stream)

@@ -187,6 +196,7 @@ def recv_partial(tensor,
                  rank_id=0,
                  group=None,
                  use_calc_stream=True):
+    # src: local rank in group
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id
@@ -198,7 +208,7 @@
                             tensor.shape)
     else:
         paddle.distributed.recv(tensor.detach(),
-                                src=src,
+                                src=group.ranks[src],
                                 group=group,
                                 use_calc_stream=use_calc_stream)

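To make the fix concrete, below is a minimal, self-contained Python sketch of the group-local to global rank translation that the diff applies everywhere. The Group class and the example global ranks 4 and 5 are hypothetical stand-ins for illustration, not Paddle APIs; in the real code the ranks list belongs to an established communication group.

# Minimal sketch (assumed names, not PaddlePaddle code): a communication
# group stores the global ranks of its members in group order, and a
# group-local index must be translated through that list before being
# passed to a send/recv primitive that expects a global rank.

class Group:
    """Hypothetical stand-in for a framework communication group."""

    def __init__(self, ranks):
        self.ranks = ranks  # global ranks of the members, in group order


def local_to_global(group, local_rank):
    # The translation used throughout the diff, e.g.
    # src_rank = group.ranks[0] and dst=group.ranks[dst].
    return group.ranks[local_rank]


if __name__ == "__main__":
    # A two-member p2p group linking pipeline stages on global ranks 4 and 5.
    group = Group(ranks=[4, 5])

    # Before the fix, recv used src=0 and send used dst=1: the group-local
    # indices leaked through as global ranks 0 and 1, which is only correct
    # for the one group that actually contains global ranks 0 and 1.
    assert local_to_global(group, 0) == 4  # sender   (src_rank)
    assert local_to_global(group, 1) == 5  # receiver (dst_rank)

Note that send_partial/recv_partial keep their dst/src parameters as local ranks and only translate in the paddle.distributed.send/recv fallback branch; the _C_ops partial-send path is left unchanged, which suggests that path already resolves its peer relative to the ring.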