未验证 提交 2e9a31eb 编写于 作者: L lilong12 提交者: GitHub

Fix bug in alltoall (#34975)

上级 a1373714
...@@ -1524,7 +1524,7 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True): ...@@ -1524,7 +1524,7 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True):
inputs={'X': [temp]}, inputs={'X': [temp]},
outputs={'Out': [out]}, outputs={'Out': [out]},
attrs={ attrs={
'ring_id': group, 'ring_id': ring_id,
'use_calc_stream': use_calc_stream, 'use_calc_stream': use_calc_stream,
}) })
out_tensor_list.extend(paddle.split(out, nranks, 0)) out_tensor_list.extend(paddle.split(out, nranks, 0))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册