From 40bfe0eb335107084ea445e19e225ca3b5f97faa Mon Sep 17 00:00:00 2001 From: ShenLiang <1422485404@qq.com> Date: Wed, 14 Jun 2023 10:58:07 +0800 Subject: [PATCH] fix bug of release output in pp (#54624) --- .../fleet/meta_parallel/pipeline_parallel.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 13788f7e165..868fb107c3e 100755 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -671,11 +671,19 @@ class PipelineParallel(MetaParallelBase): self.lr_scheduler.step() def _release_output(self, output): + def can_free(t): + return ( + t is not None + and isinstance(t, paddle.Tensor) + and t.inplace_version == 0 + ) + if isinstance(output, (tuple, list)): for t in output: - if t is not None and isinstance(t, paddle.Tensor): + if can_free(t): t._clear_dataptr() - elif output is not None and isinstance(output, paddle.Tensor): + + elif can_free(output): output._clear_dataptr() -- GitLab