未验证 提交 42431948 编写于 作者: S ShenLiang 提交者: GitHub

fix bug of p2p (#33929)

上级 a74e01ab
......@@ -13,7 +13,6 @@
# limitations under the License.
import paddle
import paddle.distributed as dist
_groups = None
_hcg = None
......@@ -21,7 +20,10 @@ _hcg = None
def initialize_p2p_groups(hcg):
global _groups, _hcg
_groups = [dist.new_group(ranks=group) for group in hcg.get_p2p_groups()]
_groups = [
paddle.distributed.new_group(ranks=group)
for group in hcg.get_p2p_groups()
]
_hcg = hcg
......@@ -33,7 +35,7 @@ def send(tensor, dest_stage):
_is_valid_communciate(src_stage, dest_stage)
group = _get_send_recv_group(src_stage, dest_stage)
dst_rank = _hcg.get_rank_from_stage(stage_id=dest_stage)
return dist.broadcast(tensor, src_rank, group=group)
return paddle.distributed.broadcast(tensor, src_rank, group=group)
def recv(tensor, src_stage):
......@@ -43,7 +45,7 @@ def recv(tensor, src_stage):
_is_valid_communciate(src_stage, dest_stage)
group = _get_send_recv_group(src_stage, dest_stage)
src_rank = _hcg.get_rank_from_stage(stage_id=src_stage)
return dist.broadcast(tensor, src_rank, group=group)
return paddle.distributed.broadcast(tensor, src_rank, group=group)
def _is_valid_communciate(src_stage, dest_stage):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册