Unverified commit 9b5fa328, authored by ShenLiang, committed by GitHub

[Cherry-Pick]Add pipeline opt memory (#54557)

* [Distributed] Add pipeline opt memory (#54505)

* fix bug of release output in pp
Parent a5d9f244
...@@ -1473,6 +1473,15 @@ static PyObject* tensor__clear(TensorObject* self,
  EAGER_CATCH_AND_THROW_RETURN_NULL
}
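// _clear_dataptr releases the tensor's underlying implementation (and with it
// the allocation), while keeping the Python-side Tensor object alive.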
static PyObject* tensor__clear_dataptr(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
self->tensor.set_impl(nullptr);
RETURN_PY_NONE
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor__copy_gradient_from(TensorObject* self,
                                            PyObject* args,
                                            PyObject* kwargs) {
...@@ -2110,6 +2119,10 @@ PyMethodDef variable_methods[] = {
     (PyCFunction)(void (*)(void))tensor__clear,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
{"_clear_dataptr",
(PyCFunction)(void (*)(void))tensor__clear_dataptr,
METH_VARARGS | METH_KEYWORDS,
NULL},
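With the entry registered above, the method is reachable from Python as `Tensor._clear_dataptr()`. A minimal sketch of its intended use (illustrative only; assumes an eager-mode Paddle build containing this patch):

```python
import paddle

x = paddle.ones([2, 3])
y = x * 2             # a forward output we no longer need locally
y._clear_dataptr()    # drops y's storage; the Python object itself survives
# Reading y's data after this point would fail, so the call is only safe
# once the tensor has served its purpose (e.g. already sent downstream).
```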
{"_copy_gradient_from", {"_copy_gradient_from",
(PyCFunction)(void (*)(void))tensor__copy_gradient_from, (PyCFunction)(void (*)(void))tensor__copy_gradient_from,
METH_VARARGS | METH_KEYWORDS, METH_VARARGS | METH_KEYWORDS,
......
...@@ -258,6 +258,9 @@ class PipelineParallel(MetaParallelBase):
            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)
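            # the forward output has already been sent to the next stage,
            # so free its storage early to cut peak activation memory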
if not self.is_pipeline_last_stage():
self._release_output(output_tensor)
        if steady_steps > 0:
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
...@@ -273,6 +276,9 @@ class PipelineParallel(MetaParallelBase):
            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)
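            # same early release during the steady 1F1B steps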
if not self.is_pipeline_last_stage():
self._release_output(output_tensor)
            input_tensor, output_tensor = input_buffers.pop(
                0
            ), output_buffers.pop(0)
...@@ -607,6 +613,22 @@ class PipelineParallel(MetaParallelBase):
        if self.lr_scheduler:
            self.lr_scheduler.step()
def _release_output(self, output):
def can_free(t):
return (
t is not None
and isinstance(t, paddle.Tensor)
and t.inplace_version == 0
)
if isinstance(output, (tuple, list)):
for t in output:
if can_free(t):
t._clear_dataptr()
elif can_free(output):
output._clear_dataptr()
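Note the `inplace_version == 0` guard: tensors that were ever written in-place are skipped, presumably because their buffers may still be subject to autograd's in-place version checks. A small illustration of the counter (a sketch, not part of the patch):

```python
import paddle

t = paddle.ones([2])
print(t.inplace_version)   # 0 -> can_free would allow releasing it
t.add_(paddle.ones([2]))   # an in-place op bumps the version counter
print(t.inplace_version)   # 1 -> can_free now skips it
```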
class PipelineParallelWithInterleave(PipelineParallel):
    # pipeline parallel with interleave scheduler
...@@ -777,6 +799,8 @@ class PipelineParallelWithInterleave(PipelineParallel):
            # append input_tensor no matter none or not
            self.input_tensors[next_virtual_pp_rank].append(input_tensor)
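            # warmup forward output is no longer needed on this rank; free its storage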
self._release_output(output_tensor)
        # run 1f1b steady steps
        for micro_step in range(steady_steps):
            # forward
...@@ -854,6 +878,9 @@ class PipelineParallelWithInterleave(PipelineParallel):
            self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                output_tensor_grad
            )
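            # free storage of the forward output now that it has been passed along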
            self._release_output(output_tensor)

        self._release_output(output_tensor)
        # remaining backward steps
        if not forward_only:
......