Unverified commit cb1a50f5, authored by Yuang Liu, committed by GitHub

fix sharding overlap (#54872)

Parent: 19ffd27d
@@ -246,17 +246,21 @@ class FusedCommBuffer:
     def _comm_grads(self):
         assert self._all_params_checked_in
 
-        if self._act == HOOK_ACTION.ALL_REDUCE:
-            task = paddle.distributed.all_reduce(
-                self.grad_storage, group=self._comm_group, sync_op=False
-            )
-        elif self._act == HOOK_ACTION.REDUCE:
-            task = paddle.distributed.reduce(
-                self.grad_storage,
-                dst=self._dst,
-                group=self._comm_group,
-                sync_op=False,
-            )
+        # Note: if sharding switches to the reduce operation, this code also needs to be updated
+        # if self._act == HOOK_ACTION.ALL_REDUCE:
+        #     task = paddle.distributed.all_reduce(
+        #         self.grad_storage, group=self._comm_group, sync_op=False
+        #     )
+        # elif self._act == HOOK_ACTION.REDUCE:
+        #     task = paddle.distributed.reduce(
+        #         self.grad_storage,
+        #         dst=self._dst,
+        #         group=self._comm_group,
+        #         sync_op=False,
+        #     )
+        task = paddle.distributed.all_reduce(
+            self.grad_storage, group=self._comm_group, sync_op=False
+        )
         self._task = task
 
     @imperative_base.no_grad
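For context, the patch makes _comm_grads always launch a non-blocking all_reduce on the fused gradient buffer, keeping the HOOK_ACTION branching around only as comments. Below is a minimal standalone sketch of the async-collective pattern this method relies on (launch the collective with sync_op=False, hold the returned task, wait on it later); the tensor and process-group setup are illustrative, not part of the patch:

    import paddle
    import paddle.distributed as dist

    # Illustrative setup; run under `python -m paddle.distributed.launch`.
    dist.init_parallel_env()

    # Stand-in for the fused gradient storage buffer.
    grad_storage = paddle.ones([8], dtype='float32')

    # Launch the collective without blocking, as _comm_grads now does
    # unconditionally (sync_op=False returns a task handle).
    task = dist.all_reduce(grad_storage, sync_op=False)

    # ... computation can overlap with communication here ...

    # Wait on the task before consuming the reduced gradients,
    # mirroring how the stored self._task is waited on later.
    task.wait()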