未验证 提交 279ac753 编写于 作者: H Haohongxiang 提交者: GitHub

fix split_tensor of dp_pp_comm_overlap (#54310)

上级 06304ade
......@@ -20,6 +20,7 @@ import paddle
from paddle import _legacy_C_ops
from paddle.distributed.parallel import _split_tensors
from paddle.fluid import core
from paddle.framework import base as imperative_base
__all__ = []
......@@ -165,6 +166,7 @@ class FusedAllReduceBuffer:
if self._all_params_checked_in:
self._fused_allreduce_grads()
@imperative_base.no_grad
def _fused_allreduce_grads(self):
assert self._all_params_checked_in
flattened_vars = []
......@@ -188,6 +190,7 @@ class FusedAllReduceBuffer:
)
)
@imperative_base.no_grad
def scale_and_split_grads(self):
for task in self._tasks:
task.wait()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册