未验证 提交 cb1a50f5 编写于 作者: Y Yuang Liu 提交者: GitHub

fix sharding overlap (#54872)

上级 19ffd27d
......@@ -246,17 +246,21 @@ class FusedCommBuffer:
def _comm_grads(self):
    """Launch the asynchronous collective communication for the fused
    gradient buffer.

    Preconditions:
        Every parameter registered with this ``FusedCommBuffer`` must have
        checked in its gradient (``self._all_params_checked_in``).

    Side effects:
        Starts a non-blocking ``all_reduce`` on ``self.grad_storage`` over
        ``self._comm_group`` and stores the in-flight communication handle
        in ``self._task`` so the caller can wait on it later.
    """
    assert self._all_params_checked_in
    # NOTE: gradients are always all-reduced here so that the sharding
    # communication overlap works correctly (see the "fix sharding
    # overlap" change).  If sharding later switches back to a REDUCE
    # (dst-rooted) operation for HOOK_ACTION.REDUCE, this call must be
    # updated to dispatch on self._act again.
    task = paddle.distributed.all_reduce(
        self.grad_storage, group=self._comm_group, sync_op=False
    )
    self._task = task
@imperative_base.no_grad
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册