未验证 提交 42431948 编写于 作者: ShenLiang 提交者: GitHub

Fix bug of p2p: drop the `dist` import alias and call `paddle.distributed` directly (#33929)

上级 a74e01ab
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
import paddle import paddle
import paddle.distributed as dist
_groups = None _groups = None
_hcg = None _hcg = None
...@@ -21,7 +20,10 @@ _hcg = None ...@@ -21,7 +20,10 @@ _hcg = None
def initialize_p2p_groups(hcg): def initialize_p2p_groups(hcg):
global _groups, _hcg global _groups, _hcg
_groups = [dist.new_group(ranks=group) for group in hcg.get_p2p_groups()] _groups = [
paddle.distributed.new_group(ranks=group)
for group in hcg.get_p2p_groups()
]
_hcg = hcg _hcg = hcg
...@@ -33,7 +35,7 @@ def send(tensor, dest_stage): ...@@ -33,7 +35,7 @@ def send(tensor, dest_stage):
_is_valid_communciate(src_stage, dest_stage) _is_valid_communciate(src_stage, dest_stage)
group = _get_send_recv_group(src_stage, dest_stage) group = _get_send_recv_group(src_stage, dest_stage)
dst_rank = _hcg.get_rank_from_stage(stage_id=dest_stage) dst_rank = _hcg.get_rank_from_stage(stage_id=dest_stage)
return dist.broadcast(tensor, src_rank, group=group) return paddle.distributed.broadcast(tensor, src_rank, group=group)
def recv(tensor, src_stage): def recv(tensor, src_stage):
...@@ -43,7 +45,7 @@ def recv(tensor, src_stage): ...@@ -43,7 +45,7 @@ def recv(tensor, src_stage):
_is_valid_communciate(src_stage, dest_stage) _is_valid_communciate(src_stage, dest_stage)
group = _get_send_recv_group(src_stage, dest_stage) group = _get_send_recv_group(src_stage, dest_stage)
src_rank = _hcg.get_rank_from_stage(stage_id=src_stage) src_rank = _hcg.get_rank_from_stage(stage_id=src_stage)
return dist.broadcast(tensor, src_rank, group=group) return paddle.distributed.broadcast(tensor, src_rank, group=group)
def _is_valid_communciate(src_stage, dest_stage): def _is_valid_communciate(src_stage, dest_stage):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册