def _release_output(self, output):
    """Call ``_clear_dataptr()`` on every eligible tensor in ``output``.

    Args:
        output: A single ``paddle.Tensor``, a ``tuple``/``list`` of items,
            or ``None``. Only the top level of a tuple/list is inspected;
            nested containers are not traversed. ``None`` entries and
            non-tensor entries are skipped silently.

    Returns:
        None. Eligible tensors are modified in place.
    """

    def can_free(t):
        # The ``is not None`` test comes first so that absent outputs are
        # rejected before any tensor attribute is touched.
        # NOTE(review): ``_is_initialized()`` presumably guards against
        # tensors that hold no storage, and ``inplace_version == 0``
        # presumably means the tensor was never modified in place -- confirm
        # against the paddle internals before relying on either reading.
        return (
            t is not None
            and isinstance(t, paddle.Tensor)
            and t._is_initialized()
            and t.inplace_version == 0
        )

    # Normalize to a sequence so a lone tensor and a tuple/list of tensors
    # share one code path instead of two duplicated branches.
    candidates = output if isinstance(output, (tuple, list)) else (output,)
    for t in candidates:
        if can_free(t):
            t._clear_dataptr()