未验证 提交 279ac753 编写于 作者: H Haohongxiang 提交者: GitHub

fix split_tensor of dp_pp_comm_overlap (#54310)

上级 06304ade
...@@ -20,6 +20,7 @@ import paddle ...@@ -20,6 +20,7 @@ import paddle
from paddle import _legacy_C_ops from paddle import _legacy_C_ops
from paddle.distributed.parallel import _split_tensors from paddle.distributed.parallel import _split_tensors
from paddle.fluid import core from paddle.fluid import core
from paddle.framework import base as imperative_base
__all__ = [] __all__ = []
...@@ -165,6 +166,7 @@ class FusedAllReduceBuffer: ...@@ -165,6 +166,7 @@ class FusedAllReduceBuffer:
if self._all_params_checked_in: if self._all_params_checked_in:
self._fused_allreduce_grads() self._fused_allreduce_grads()
@imperative_base.no_grad
def _fused_allreduce_grads(self): def _fused_allreduce_grads(self):
assert self._all_params_checked_in assert self._all_params_checked_in
flattened_vars = [] flattened_vars = []
...@@ -188,6 +190,7 @@ class FusedAllReduceBuffer: ...@@ -188,6 +190,7 @@ class FusedAllReduceBuffer:
) )
) )
@imperative_base.no_grad
def scale_and_split_grads(self): def scale_and_split_grads(self):
for task in self._tasks: for task in self._tasks:
task.wait() task.wait()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册