未验证 提交 6a486ec2 编写于 作者: L lilong12 提交者: GitHub

update (#41636)

上级 45e43dfe
...@@ -849,7 +849,9 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True): ...@@ -849,7 +849,9 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
if in_dygraph_mode(): if in_dygraph_mode():
group = _get_default_group() if group is None else group group = _get_default_group() if group is None else group
out = paddle.concat(tensor_list) tensor_shape = list(tensor.shape)
tensor_shape[0] *= group.nranks
out = paddle.empty(tensor_shape, tensor.dtype)
task = group.process_group.all_gather(tensor, out) task = group.process_group.all_gather(tensor, out)
task.wait() task.wait()
tensor_list.clear() tensor_list.clear()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册