Unverified commit 0cc46b4a, authored by ShenLiang, committed via GitHub

[Distributed]Opt memory pp & vp (#54325)

* opt memory

* rm args

* add rm ptr
Parent commit: 05135113
......@@ -1462,6 +1462,15 @@ static PyObject* tensor__clear(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
// Python binding for Tensor._clear_dataptr().
// Resets the tensor's underlying impl pointer to nullptr, which drops the
// last reference this Python object holds to the tensor storage — presumably
// releasing the data buffer if no one else shares it (used by the pipeline
// parallel scheduler to free forward-activation memory early; verify against
// phi::TensorBase refcounting). The Python-side Tensor wrapper itself stays
// alive. args/kwargs are accepted (METH_VARARGS | METH_KEYWORDS registration)
// but ignored. EAGER_TRY / EAGER_CATCH_AND_THROW_RETURN_NULL are Paddle's
// C++-exception-to-Python-error translation macros; RETURN_PY_NONE returns
// an incref'd Py_None.
static PyObject* tensor__clear_dataptr(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
self->tensor.set_impl(nullptr);
RETURN_PY_NONE
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor__copy_gradient_from(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
......@@ -2099,6 +2108,10 @@ PyMethodDef variable_methods[] = {
(PyCFunction)(void (*)(void))tensor__clear,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"_clear_dataptr",
(PyCFunction)(void (*)(void))tensor__clear_dataptr,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"_copy_gradient_from",
(PyCFunction)(void (*)(void))tensor__copy_gradient_from,
METH_VARARGS | METH_KEYWORDS,
......
......@@ -258,6 +258,9 @@ class PipelineParallel(MetaParallelBase):
input_buffers.append(input_tensor)
output_buffers.append(output_tensor)
if not self.is_pipeline_last_stage():
self._release_output(output_tensor)
if steady_steps > 0:
input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
......@@ -273,6 +276,9 @@ class PipelineParallel(MetaParallelBase):
input_buffers.append(input_tensor)
output_buffers.append(output_tensor)
if not self.is_pipeline_last_stage():
self._release_output(output_tensor)
input_tensor, output_tensor = input_buffers.pop(
0
), output_buffers.pop(0)
......@@ -607,6 +613,14 @@ class PipelineParallel(MetaParallelBase):
if self.lr_scheduler:
self.lr_scheduler.step()
def _release_output(self, output):
if isinstance(output, (tuple, list)):
for t in output:
if t is not None and isinstance(t, paddle.Tensor):
t._clear_dataptr()
elif output is not None and isinstance(output, paddle.Tensor):
output._clear_dataptr()
class PipelineParallelWithInterleave(PipelineParallel):
# pipeline parallel with interleave scheduler
......@@ -782,6 +796,8 @@ class PipelineParallelWithInterleave(PipelineParallel):
)
self.input_tensors[next_virtual_pp_rank].append(input_tensor)
self._release_output(output_tensor)
# run 1f1b steady steps
for micro_step in range(steady_steps):
# forward
......@@ -861,6 +877,8 @@ class PipelineParallelWithInterleave(PipelineParallel):
recv_next=recv_next,
)
self._release_output(output_tensor)
if recv_prev:
self.input_tensors[next_forward_virtual_pp_rank].append(
input_tensor
......@@ -870,6 +888,8 @@ class PipelineParallelWithInterleave(PipelineParallel):
output_tensor_grad
)
self._release_output(output_tensor)
# remaining backward steps
if not forward_only:
if all_startup_steps:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册