diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc
index e78c443e33a5ed96417f7cd6067434f57e5c24a9..4c88cf30223c78afc7b5342cee5a6e0c56564637 100644
--- a/paddle/fluid/pybind/eager_method.cc
+++ b/paddle/fluid/pybind/eager_method.cc
@@ -1473,6 +1473,15 @@ static PyObject* tensor__clear(TensorObject* self,
   EAGER_CATCH_AND_THROW_RETURN_NULL
 }
 
+static PyObject* tensor__clear_dataptr(TensorObject* self,
+                                       PyObject* args,
+                                       PyObject* kwargs) {
+  EAGER_TRY
+  self->tensor.set_impl(nullptr);
+  RETURN_PY_NONE
+  EAGER_CATCH_AND_THROW_RETURN_NULL
+}
+
 static PyObject* tensor__copy_gradient_from(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
@@ -2110,6 +2119,10 @@ PyMethodDef variable_methods[] = {
      (PyCFunction)(void (*)(void))tensor__clear,
      METH_VARARGS | METH_KEYWORDS,
      NULL},
+    {"_clear_dataptr",
+     (PyCFunction)(void (*)(void))tensor__clear_dataptr,
+     METH_VARARGS | METH_KEYWORDS,
+     NULL},
     {"_copy_gradient_from",
      (PyCFunction)(void (*)(void))tensor__copy_gradient_from,
      METH_VARARGS | METH_KEYWORDS,
diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index f898b2bb5077efe6cc0f119614e6758838a34660..91d79206afed82a34681d69b3e12ddf6192cfdb6 100755
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -259,6 +259,9 @@ class PipelineParallel(MetaParallelBase):
             input_buffers.append(input_tensor)
             output_buffers.append(output_tensor)
 
+            if not self.is_pipeline_last_stage():
+                self._release_output(output_tensor)
+
         if steady_steps > 0:
             input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
 
@@ -274,6 +277,9 @@ class PipelineParallel(MetaParallelBase):
             input_buffers.append(input_tensor)
             output_buffers.append(output_tensor)
 
+            if not self.is_pipeline_last_stage():
+                self._release_output(output_tensor)
+
             input_tensor, output_tensor = input_buffers.pop(
                 0
             ), output_buffers.pop(0)
@@ -608,6 +614,14 @@ class PipelineParallel(MetaParallelBase):
         if self.lr_scheduler:
             self.lr_scheduler.step()
 
+    def _release_output(self, output):
+        if isinstance(output, (tuple, list)):
+            for t in output:
+                if t is not None and isinstance(t, paddle.Tensor):
+                    t._clear_dataptr()
+        elif output is not None and isinstance(output, paddle.Tensor):
+            output._clear_dataptr()
+
 
 class PipelineParallelWithInterleave(PipelineParallel):
     # pipeline parallel with interleave scheduler
@@ -782,6 +796,8 @@ class PipelineParallelWithInterleave(PipelineParallel):
             # append input_tensor no matter none or not
             self.input_tensors[next_virtual_pp_rank].append(input_tensor)
 
+            self._release_output(output_tensor)
+
         # run 1f1b steady steps
         for micro_step in range(steady_steps):
             # forward
@@ -859,6 +875,9 @@ class PipelineParallelWithInterleave(PipelineParallel):
             self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                 output_tensor_grad
             )
+            self._release_output(output_tensor)
+
+        self._release_output(output_tensor)
 
         # remaining backward steps
         if not forward_only:
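
For context, a minimal sketch of how the new `_clear_dataptr` binding is meant to be used (illustrative only, not part of the patch; the tensor name and shape below are made up): once a stage has sent its forward output downstream, the Python-side object is kept only for bookkeeping, so clearing its impl releases the activation memory early.

    import paddle

    # Illustrative only: mirrors what _release_output does for a single tensor.
    # After p2p.send_forward has handed the data to the next stage, the local
    # copy is no longer read, so its storage can be dropped.
    out = paddle.randn([8, 1024])   # stands in for a forward output tensor
    out._clear_dataptr()            # calls set_impl(nullptr) on the C++ side,
                                    # releasing the underlying allocation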