diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index 13788f7e165ea64ee7c2af5dcac06c0d1de0275b..868fb107c3ed9b3845e4375882bd9cb7d6e9b55d 100755
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -671,11 +671,19 @@ class PipelineParallel(MetaParallelBase):
             self.lr_scheduler.step()
 
     def _release_output(self, output):
+        def can_free(t):
+            return (
+                t is not None
+                and isinstance(t, paddle.Tensor)
+                and t.inplace_version == 0
+            )
+
         if isinstance(output, (tuple, list)):
             for t in output:
-                if t is not None and isinstance(t, paddle.Tensor):
+                if can_free(t):
                     t._clear_dataptr()
-        elif output is not None and isinstance(output, paddle.Tensor):
+
+        elif can_free(output):
             output._clear_dataptr()
 
 