Unverified commit 5a1b6f5d, authored by S ShenLiang, committed by GitHub

add p2p (#50337)

Parent 913f40ee
@@ -17,7 +17,11 @@ from ...utils.log_util import logger
import numpy as np
from paddle import _C_ops, _legacy_C_ops
import paddle.fluid.core as core
from paddle.fluid.framework import (
    _in_legacy_dygraph,
    _non_static_mode,
    in_dygraph_mode,
)
from .utils import paddle_2_number, number_2_dtype

_hcg = None
@@ -30,12 +34,23 @@ def initialize_p2p_groups(hcg, use_cache=True, enable_partial_send_recv=True):
    _hcg = hcg
    _use_cache = use_cache
    _enable_partial_send_recv = enable_partial_send_recv
    (
        send_next_group,
        send_prev_group,
        recv_next_group,
        recv_prev_group,
    ) = _hcg.get_p2p_groups()

    debug_str = (
        "P2pInfo: send_next_group: %s, send_prev_group: %s, "
        "recv_next_group: %s, recv_prev_group: %s"
        % (
            repr(send_next_group),
            repr(send_prev_group),
            repr(recv_next_group),
            repr(recv_prev_group),
        )
    )
    logger.info(debug_str)
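
The tuple unpacking above assumes `_hcg.get_p2p_groups()` hands back exactly four process groups, in the order send-next, send-prev, recv-next, recv-prev. A minimal sketch of that contract, using a hypothetical stand-in rather than the real hcg object:

class _ToyHcg:  # hypothetical stand-in, only shows the 4-tuple contract
    def get_p2p_groups(self):
        return ("send_next_grp", "send_prev_grp", "recv_next_grp", "recv_prev_grp")

(
    send_next_group,
    send_prev_group,
    recv_next_group,
    recv_prev_group,
) = _ToyHcg().get_p2p_groups()
print(send_next_group, recv_prev_group)  # send_next_grp recv_prev_grp
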
@@ -150,9 +165,15 @@ class SendRecvMeta:
            self.send_dtype_message = paddle_2_number(tensor.dtype)
        elif isinstance(tensor, tuple):
            self.send_shape_message = tuple(
                [d.shape for d in tensor if not d.stop_gradient]
            )
            self.send_dtype_message = tuple(
                [
                    paddle_2_number(d.dtype)
                    for d in tensor
                    if not d.stop_gradient
                ]
            )


_send_recv_meta = SendRecvMeta()
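
The tuple branch above carries the functional change in this commit: the dtype message now applies the same `not d.stop_gradient` filter as the shape message, presumably so the two tuples stay index-aligned when some outputs (for example a mask) carry no gradient. A framework-free illustration, where FakeTensor is a stand-in rather than a real paddle.Tensor:

from collections import namedtuple

FakeTensor = namedtuple("FakeTensor", ["shape", "dtype", "stop_gradient"])

tensors = (
    FakeTensor([4, 16], "float32", False),
    FakeTensor([4, 1], "int64", True),    # e.g. a mask that carries no gradient
    FakeTensor([4, 16], "float32", False),
)

send_shape_message = tuple(d.shape for d in tensors if not d.stop_gradient)
send_dtype_message = tuple(d.dtype for d in tensors if not d.stop_gradient)

assert len(send_shape_message) == len(send_dtype_message) == 2
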
@@ -166,84 +187,117 @@ def _is_valid_send_recv_partial(tensor, mp_degree):
    return mp_degree > 1 and tensor_numel % mp_degree == 0


def _partial_send_op(
    tensor, group, use_calc_stream, ring_id, dst, nranks, rank_id
):
    dst_rank_in_group = dst if group is None else group.get_group_rank(dst)
    if _in_legacy_dygraph():
        return _legacy_C_ops.partial_send(
            tensor.detach(),
            'use_calc_stream',
            use_calc_stream,
            'ring_id',
            ring_id,
            'peer',
            dst_rank_in_group,
            'num',
            nranks,
            'id',
            rank_id,
        )
    elif in_dygraph_mode():
        group = (
            paddle.distributed.collective._get_default_group()
            if group is None
            else group
        )
        comm_op = (
            group.process_group.send_partial_on_calc_stream
            if use_calc_stream
            else group.process_group.send_partial
        )
        return comm_op(tensor, dst_rank_in_group, nranks, rank_id)


def send_partial(
    tensor, dst=0, nranks=1, rank_id=0, group=None, use_calc_stream=True
):
    # dst: local rank in group
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

    dst_rank = (
        _hcg._get_p2p_next_rank() if dst == 1 else _hcg._get_p2p_prev_rank()
    )

    if _is_valid_send_recv_partial(tensor, nranks):
        return _partial_send_op(
            tensor, group, use_calc_stream, ring_id, dst_rank, nranks, rank_id
        )
    else:
        if _in_legacy_dygraph():
            send_op = lambda x, dst, group: paddle.distributed.send(
                x, dst, group, use_calc_stream
            )
        elif in_dygraph_mode():
            send_op = paddle.distributed.isend
        return send_op(tensor.detach(), dst=dst_rank, group=group)
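
A partial transfer is only attempted when the flattened tensor divides evenly across the model-parallel degree, which is what `_is_valid_send_recv_partial` checks; each rank then ships a numel/nranks slice selected by `rank_id`. A small numpy sketch of that arithmetic (the contiguous-slice layout is an inference from the `nranks`/`id` arguments, not something the ops above spell out):

import numpy as np

def is_valid_send_recv_partial(numel, mp_degree):
    # mirrors the check above: only split when mp_degree > 1 and it divides numel
    return mp_degree > 1 and numel % mp_degree == 0

activation = np.arange(12, dtype=np.float32)    # flattened activation
nranks, rank_id = 3, 1
assert is_valid_send_recv_partial(activation.size, nranks)

step = activation.size // nranks
sent_by_this_rank = activation[rank_id * step:(rank_id + 1) * step]
print(sent_by_this_rank)                        # [4. 5. 6. 7.]
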
def _partial_recv_op(
    tensor, group, use_calc_stream, ring_id, src, nranks, rank_id
):
    src_rank_in_group = src if group is None else group.get_group_rank(src)
    if _in_legacy_dygraph():
        assert use_calc_stream
        return _legacy_C_ops.partial_recv(
            tensor.detach(),
            'use_calc_stream',
            use_calc_stream,
            'ring_id',
            ring_id,
            'peer',
            src_rank_in_group,
            'num',
            nranks,
            'id',
            rank_id,
            'dtype',
            tensor.dtype,
            'out_shape',
            tensor.shape,
        )
    elif in_dygraph_mode():
        group = (
            paddle.distributed.collective._get_default_group()
            if group is None
            else group
        )
        comm_op = (
            group.process_group.recv_partial_on_calc_stream
            if use_calc_stream
            else group.process_group.recv_partial
        )
        return comm_op(tensor, src_rank_in_group, nranks, rank_id)


def recv_partial(
    tensor, src=0, nranks=1, rank_id=0, group=None, use_calc_stream=True
):
    # src: local rank in group
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

    src_rank = (
        _hcg._get_p2p_prev_rank() if src == 0 else _hcg._get_p2p_next_rank()
    )

    if _is_valid_send_recv_partial(tensor, nranks):
        return _partial_recv_op(
            tensor, group, use_calc_stream, ring_id, src_rank, nranks, rank_id
        )
    else:
        if _in_legacy_dygraph() or use_calc_stream:
            recv_op = paddle.distributed.recv
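
# Note on the dst/src arguments used above: they are direction selectors inside
# the two-rank p2p pair rather than global ranks. send_partial maps dst == 1 to
# the next pipeline stage (anything else means the previous one), and
# recv_partial maps src == 0 to the previous stage, so 0 means "prev" and 1
# means "next" in both helpers. Illustration only, not part of the module:
def _resolve_peer_for_illustration(direction, prev_rank, next_rank):
    return next_rank if direction == 1 else prev_rank

assert _resolve_peer_for_illustration(0, prev_rank=2, next_rank=4) == 2
assert _resolve_peer_for_illustration(1, prev_rank=2, next_rank=4) == 4
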
@@ -252,42 +306,52 @@ def recv_partial(tensor,
        return recv_op(tensor.detach(), src=src_rank, group=group)


def _partial_allgather_op(
    tensor, group, use_calc_stream, ring_id, nranks, rank_id
):
    if _in_legacy_dygraph():
        return _legacy_C_ops.partial_allgather_(
            tensor.detach(),
            'use_calc_stream',
            use_calc_stream,
            'ring_id',
            ring_id,
            'nranks',
            nranks,
            'rank',
            rank_id,
        )
    elif in_dygraph_mode():
        group = (
            paddle.distributed.collective._get_default_group()
            if group is None
            else group
        )
        comm_op = (
            group.process_group.all_gather_partial_on_calc_stream
            if use_calc_stream
            else group.process_group.all_gather_partial
        )
        return comm_op(tensor, tensor, nranks, rank_id)


def allgather_partial(
    tensor, nranks=1, rank_id=0, group=None, use_calc_stream=True
):
    if not _is_valid_send_recv_partial(tensor, nranks):
        return tensor
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

    return _partial_allgather_op(
        tensor, group, use_calc_stream, ring_id, nranks, rank_id
    )


def _p2p_helper(
    tensor_send_next, tensor_send_prev, recv_prev, recv_next, sync_recv=True
):
    global _hcg

    tensor_recv_prev = None
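
# Rough numpy analogue of the recv_partial + allgather_partial pattern used in
# the body below: each model-parallel rank receives only its own 1/nranks slice
# of the buffer, and the partial all-gather over the mp group then fills in the
# other slices so every rank ends up with the full tensor. The contiguous-slice
# layout is an assumption; no real communication happens here.
import numpy as np

nranks = 2
payload = np.arange(8, dtype=np.float32)          # what the sending stage holds
step = payload.size // nranks

buffers = [np.zeros_like(payload) for _ in range(nranks)]
for r in range(nranks):                           # "partial recv": own slice only
    buffers[r][r * step:(r + 1) * step] = payload[r * step:(r + 1) * step]

for r in range(nranks):                           # "allgather_partial"
    for other in range(nranks):
        start, stop = other * step, (other + 1) * step
        buffers[r][start:stop] = buffers[other][start:stop]

assert all(np.array_equal(b, payload) for b in buffers)
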
@@ -310,15 +374,17 @@ def _p2p_helper(tensor_send_next,
        if isinstance(recv_shape_msg, tuple):
            tensor_recv_prev = []
            for idx, shape in enumerate(recv_shape_msg):
                tmp = paddle.empty(
                    shape=shape, dtype=number_2_dtype(recv_dtype_msg[idx])
                )
                tmp.stop_gradient = recv_stop_gradient[idx]
                tensor_recv_prev.append(tmp)
            tensor_recv_prev = tuple(tensor_recv_prev)
        else:
            tensor_recv_prev = paddle.empty(
                shape=recv_shape_msg, dtype=number_2_dtype(recv_dtype_msg)
            )
            tensor_recv_prev.stop_gradient = recv_stop_gradient

    if recv_next:
@@ -326,12 +392,15 @@ def _p2p_helper(tensor_send_next,
            tensor_recv_next = []
            for idx, shape in enumerate(send_shape_msg):
                tensor_recv_next.append(
                    paddle.empty(
                        shape=shape, dtype=number_2_dtype(send_dtype_msg[idx])
                    )
                )
            tensor_recv_next = tuple(tensor_recv_next)
        else:
            tensor_recv_next = paddle.empty(
                shape=send_shape_msg, dtype=number_2_dtype(send_dtype_msg)
            )

    # TODO(Yuang Liu): use batch_isend_irecv replace all these comm ops
    tasks = []
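
# How the cached meta is turned back into receive buffers in the code above:
# every entry of recv_shape_msg gets an uninitialized tensor of the matching
# dtype, and its stop_gradient flag is restored from the received meta. The
# number-to-dtype table below is a made-up stand-in for number_2_dtype from
# .utils; only the allocation pattern mirrors the real code.
import paddle

_TOY_NUMBER_2_DTYPE = {0: "float32", 1: "float16", 2: "int64"}  # hypothetical codes

recv_shape_msg = ([4, 16], [4, 16])
recv_dtype_msg = (0, 2)
recv_stop_gradient = (False, True)

buffers = []
for idx, shape in enumerate(recv_shape_msg):
    tmp = paddle.empty(shape=shape, dtype=_TOY_NUMBER_2_DTYPE[recv_dtype_msg[idx]])
    tmp.stop_gradient = recv_stop_gradient[idx]
    buffers.append(tmp)
buffers = tuple(buffers)
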
@@ -340,51 +409,63 @@ def _p2p_helper(tensor_send_next,
        if isinstance(tensor_send_prev, tuple):
            for d in tensor_send_prev:
                paddle.distributed.wait(d, use_calc_stream=True)
                send_partial(
                    d,
                    dst=0,
                    nranks=mp_degree,
                    rank_id=mp_rank,
                    group=_hcg.send_prev_group,
                    use_calc_stream=False,
                )
        else:
            paddle.distributed.wait(tensor_send_prev, use_calc_stream=True)
            send_partial(
                tensor_send_prev,
                dst=0,
                nranks=mp_degree,
                rank_id=mp_rank,
                group=_hcg.send_prev_group,
                use_calc_stream=False,
            )

    if tensor_recv_prev is not None:
        if isinstance(tensor_recv_prev, tuple):
            for d in tensor_recv_prev:
                task = recv_partial(
                    d,
                    src=0,
                    nranks=mp_degree,
                    rank_id=mp_rank,
                    group=_hcg.recv_prev_group,
                    use_calc_stream=sync_recv,
                )
                if sync_recv:
                    allgather_partial(
                        d,
                        nranks=mp_degree,
                        rank_id=mp_rank,
                        group=mp_group,
                        use_calc_stream=True,
                    )
                else:
                    tasks.append(task)
        else:
            task = recv_partial(
                tensor_recv_prev,
                src=0,
                nranks=mp_degree,
                rank_id=mp_rank,
                group=_hcg.recv_prev_group,
                use_calc_stream=sync_recv,
            )
            if sync_recv:
                allgather_partial(
                    tensor_recv_prev,
                    nranks=mp_degree,
                    rank_id=mp_rank,
                    group=mp_group,
                    use_calc_stream=True,
                )
            else:
                tasks.append(task)
@@ -392,52 +473,64 @@ def _p2p_helper(tensor_send_next,
        if isinstance(tensor_send_next, tuple):
            for d in tensor_send_next:
                paddle.distributed.wait(d, use_calc_stream=True)
                send_partial(
                    d,
                    dst=1,
                    nranks=mp_degree,
                    rank_id=mp_rank,
                    group=_hcg.send_next_group,
                    use_calc_stream=False,
                )
        else:
            paddle.distributed.wait(tensor_send_next, use_calc_stream=True)
            send_partial(
                tensor_send_next,
                dst=1,
                nranks=mp_degree,
                rank_id=mp_rank,
                group=_hcg.send_next_group,
                use_calc_stream=False,
            )

    if tensor_recv_next is not None:
        if isinstance(tensor_recv_next, tuple):
            for d in tensor_recv_next:
                task = recv_partial(
                    d,
                    src=1,
                    nranks=mp_degree,
                    rank_id=mp_rank,
                    group=_hcg.recv_next_group,
                    use_calc_stream=sync_recv,
                )
                if sync_recv:
                    allgather_partial(
                        d,
                        nranks=mp_degree,
                        rank_id=mp_rank,
                        group=mp_group,
                        use_calc_stream=True,
                    )
                else:
                    tasks.append(task)
        else:
            task = recv_partial(
                tensor_recv_next,
                src=1,
                nranks=mp_degree,
                rank_id=mp_rank,
                group=_hcg.recv_next_group,
                use_calc_stream=sync_recv,
            )
            if sync_recv:
                allgather_partial(
                    tensor_recv_next,
                    nranks=mp_degree,
                    rank_id=mp_rank,
                    group=mp_group,
                    use_calc_stream=True,
                )
            else:
                tasks.append(task)
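
# The sync_recv flag above decides whether each partial receive is completed on
# the spot (immediately followed by its allgather) or returns a task handle
# that is stashed in `tasks` and finished later. The futures below are only a
# scheduling analogy for that deferred wait, not Paddle's communication tasks.
from concurrent.futures import ThreadPoolExecutor

def fake_recv(name):
    return "received " + name

sync_recv = False
tasks_demo, results = [], []
with ThreadPoolExecutor(max_workers=2) as pool:
    for name in ("prev", "next"):
        if sync_recv:
            results.append(fake_recv(name))             # completed immediately
        else:
            tasks_demo.append(pool.submit(fake_recv, name))  # completed later
    results.extend(t.result() for t in tasks_demo)      # the deferred "wait"
print(results)
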
@@ -463,11 +556,13 @@ def _p2p_helper(tensor_send_next,
            tensors_for_all_gather.append(tensor_recv_next)

    for tensor in tensors_for_all_gather:
        allgather_partial(
            tensor,
            nranks=mp_degree,
            rank_id=mp_rank,
            group=mp_group,
            use_calc_stream=True,
        )

    return tensor_recv_prev, tensor_recv_next
@@ -480,11 +575,13 @@ def recv_forward(pp_first_stage, sync_recv=True):
            _send_recv_meta.recv_meta(_hcg.recv_prev_group)
            _send_recv_meta.has_recv_meta = _use_cache

        input_tensor, _ = _p2p_helper(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_prev=True,
            recv_next=False,
            sync_recv=sync_recv,
        )
    return input_tensor
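
The `has_recv_meta = _use_cache` assignment above means the shape/dtype handshake runs once per training run when caching is enabled and on every call when it is disabled. A stripped-down model of that guard:

class MetaCache:
    def __init__(self, use_cache):
        self.use_cache = use_cache
        self.has_recv_meta = False
        self.exchanges = 0

    def maybe_exchange(self):
        if not self.has_recv_meta:
            self.exchanges += 1          # stands in for recv_meta(...)
            self.has_recv_meta = self.use_cache

cached, uncached = MetaCache(True), MetaCache(False)
for _ in range(3):
    cached.maybe_exchange()
    uncached.maybe_exchange()
assert (cached.exchanges, uncached.exchanges) == (1, 3)
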
@@ -492,11 +589,13 @@ def recv_backward(pp_last_stage, sync_recv=True):
    if pp_last_stage:
        output_tensor_grad = None
    else:
        _, output_tensor_grad = _p2p_helper(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=True,
            sync_recv=sync_recv,
        )
    return output_tensor_grad
@@ -507,28 +606,34 @@ def send_forward(output_tensor, pp_last_stage):
            _send_recv_meta.send_meta(output_tensor, _hcg.send_next_group)
            _send_recv_meta.has_send_meta = _use_cache

        _p2p_helper(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=False,
        )


def send_backward(input_tensor_grad, pp_first_stage):
    if not pp_first_stage:
        _p2p_helper(
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=False,
            recv_next=False,
        )


def send_forward_recv_backward(output_tensor, pp_last_stage):
    if pp_last_stage:
        output_tensor_grad = None
    else:
        _, output_tensor_grad = _p2p_helper(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=True,
        )
    return output_tensor_grad
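
Taken together, these wrappers pair up across adjacent pipeline stages: `send_forward` on stage i feeds `recv_forward` on stage i+1, and `send_backward` on stage i+1 feeds `recv_backward` on stage i, while the fused variants fold a send and a receive into a single `_p2p_helper` call. A framework-free schematic of that pairing, with a plain dict standing in for the p2p groups:

wire = {}                                # stands in for the send/recv groups

def toy_send_forward(activation):        # stage i -> stage i + 1
    wire["activation"] = activation

def toy_recv_forward():                  # stage i + 1, pulls what stage i sent
    return wire["activation"]

def toy_send_backward(grad):             # stage i + 1 -> stage i
    wire["grad"] = grad

def toy_recv_backward():                 # stage i, pulls what stage i + 1 sent
    return wire["grad"]

toy_send_forward("act@i")
assert toy_recv_forward() == "act@i"
toy_send_backward("grad@i+1")
assert toy_recv_backward() == "grad@i+1"
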
@@ -536,16 +641,18 @@ def send_backward_recv_forward(input_tensor_grad, pp_first_stage):
    if pp_first_stage:
        input_tensor = None
    else:
        input_tensor, _ = _p2p_helper(
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=True,
            recv_next=False,
        )
    return input_tensor


def send_forward_backward_recv_forward_backward(
    output_tensor, input_tensor_grad, recv_prev, recv_next
):
    # always have to send dtype info to downstream
    if not _send_recv_meta.has_send_meta:
        _send_recv_meta.set_send_message(output_tensor)
@@ -559,7 +666,8 @@ def send_forward_backward_recv_forward_backward(output_tensor,
        tensor_send_prev=input_tensor_grad,
        recv_prev=recv_prev,
        recv_next=recv_next,
        sync_recv=False,
    )
    return input_tensor, output_tensor_grad
@@ -573,19 +681,23 @@ def send_forward_recv_forward(output_tensor, recv_prev):
            _send_recv_meta.recv_meta(_hcg.recv_prev_group)
            _send_recv_meta.has_recv_meta = _use_cache

    input_tensor, _ = _p2p_helper(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=recv_prev,
        recv_next=False,
        sync_recv=False,
    )
    return input_tensor


def send_backward_recv_backward(input_tensor_grad, recv_next):
    _, output_tensor_grad = _p2p_helper(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=False,
        recv_next=recv_next,
        sync_recv=False,
    )
    return output_tensor_grad
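
As a quick cross-reference assembled from the wrappers above, the table below lists which directions each helper drives in its `_p2p_helper` call, as (sends_next, sends_prev, recv_prev, recv_next); for the last three entries the receive flags are caller-supplied and are shown here in their enabled state:

P2P_WRAPPERS = {
    "recv_forward":                                 (False, False, True,  False),
    "recv_backward":                                (False, False, False, True),
    "send_forward":                                 (True,  False, False, False),
    "send_backward":                                (False, True,  False, False),
    "send_forward_recv_backward":                   (True,  False, False, True),
    "send_backward_recv_forward":                   (False, True,  True,  False),
    "send_forward_recv_forward":                    (True,  False, True,  False),
    "send_backward_recv_backward":                  (False, True,  False, True),
    "send_forward_backward_recv_forward_backward":  (True,  True,  True,  True),
}

for name, flags in P2P_WRAPPERS.items():
    print(f"{name:46s} sends_next={flags[0]} sends_prev={flags[1]} "
          f"recv_prev={flags[2]} recv_next={flags[3]}")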