From ba75fbec1bf5f479336d09d6a0df872b7580b19a Mon Sep 17 00:00:00 2001 From: ShenLiang <1422485404@qq.com> Date: Fri, 16 Jun 2023 12:07:32 +0800 Subject: [PATCH] [BugFix] fix bug of release output in pp (#54625) --- .../fleet/meta_parallel/pipeline_parallel.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index a3f840cf770..3bc9fc1ccf2 100755 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -673,11 +673,20 @@ class PipelineParallel(MetaParallelBase): self.lr_scheduler.step() def _release_output(self, output): + def can_free(t): + return ( + t is not None + and isinstance(t, paddle.Tensor) + and t._is_initialized() + and t.inplace_version == 0 + ) + if isinstance(output, (tuple, list)): for t in output: - if t is not None and isinstance(t, paddle.Tensor): + if can_free(t): t._clear_dataptr() - elif output is not None and isinstance(output, paddle.Tensor): + + elif can_free(output): output._clear_dataptr() -- GitLab