提交 6972bfde 编写于 作者: M Megvii Engine Team

fix(mge/distributed): fix input comp_graph of broadcast operator

GitOrigin-RevId: 039fd06a933b4d1e9a74fbb5a3eaf467cc486178
上级 66950a4f
......@@ -112,26 +112,20 @@ def broadcast(
rank = get_rank()
if rank == root:
return _collective_comm(
tensor,
key,
CollParam.Mode.BROADCAST,
nr_ranks,
rank,
root,
device=tensor.device,
)
inp = tensor
else:
return _collective_comm(
get_default_graph(),
key,
CollParam.Mode.BROADCAST,
nr_ranks,
rank,
root,
dtype=tensor._symvar.dtype,
device=tensor.device,
)
inp = tensor._symvar.owner_graph
return _collective_comm(
inp,
key,
CollParam.Mode.BROADCAST,
nr_ranks,
rank,
root,
dtype=tensor.dtype,
device=tensor.device,
)
def all_gather(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册