Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
a56eba3a
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a56eba3a
编写于
6月 12, 2023
作者:
S
ShenLiang
提交者:
GitHub
6月 12, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Distributed] Add pipeline opt memory (#54505)
上级
eca64f0f
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
32 addition
and
0 deletion
+32
-0
paddle/fluid/pybind/eager_method.cc
paddle/fluid/pybind/eager_method.cc
+13
-0
python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
...ddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+19
-0
未找到文件。
paddle/fluid/pybind/eager_method.cc
浏览文件 @
a56eba3a
...
@@ -1473,6 +1473,15 @@ static PyObject* tensor__clear(TensorObject* self,
...
@@ -1473,6 +1473,15 @@ static PyObject* tensor__clear(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL
EAGER_CATCH_AND_THROW_RETURN_NULL
}
}
static
PyObject
*
tensor__clear_dataptr
(
TensorObject
*
self
,
PyObject
*
args
,
PyObject
*
kwargs
)
{
EAGER_TRY
self
->
tensor
.
set_impl
(
nullptr
);
RETURN_PY_NONE
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static
PyObject
*
tensor__copy_gradient_from
(
TensorObject
*
self
,
static
PyObject
*
tensor__copy_gradient_from
(
TensorObject
*
self
,
PyObject
*
args
,
PyObject
*
args
,
PyObject
*
kwargs
)
{
PyObject
*
kwargs
)
{
...
@@ -2110,6 +2119,10 @@ PyMethodDef variable_methods[] = {
...
@@ -2110,6 +2119,10 @@ PyMethodDef variable_methods[] = {
(
PyCFunction
)(
void
(
*
)(
void
))
tensor__clear
,
(
PyCFunction
)(
void
(
*
)(
void
))
tensor__clear
,
METH_VARARGS
|
METH_KEYWORDS
,
METH_VARARGS
|
METH_KEYWORDS
,
NULL
},
NULL
},
{
"_clear_dataptr"
,
(
PyCFunction
)(
void
(
*
)(
void
))
tensor__clear_dataptr
,
METH_VARARGS
|
METH_KEYWORDS
,
NULL
},
{
"_copy_gradient_from"
,
{
"_copy_gradient_from"
,
(
PyCFunction
)(
void
(
*
)(
void
))
tensor__copy_gradient_from
,
(
PyCFunction
)(
void
(
*
)(
void
))
tensor__copy_gradient_from
,
METH_VARARGS
|
METH_KEYWORDS
,
METH_VARARGS
|
METH_KEYWORDS
,
...
...
python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
浏览文件 @
a56eba3a
...
@@ -259,6 +259,9 @@ class PipelineParallel(MetaParallelBase):
...
@@ -259,6 +259,9 @@ class PipelineParallel(MetaParallelBase):
input_buffers
.
append
(
input_tensor
)
input_buffers
.
append
(
input_tensor
)
output_buffers
.
append
(
output_tensor
)
output_buffers
.
append
(
output_tensor
)
if
not
self
.
is_pipeline_last_stage
():
self
.
_release_output
(
output_tensor
)
if
steady_steps
>
0
:
if
steady_steps
>
0
:
input_tensor
=
p2p
.
recv_forward
(
self
.
is_pipeline_first_stage
())
input_tensor
=
p2p
.
recv_forward
(
self
.
is_pipeline_first_stage
())
...
@@ -274,6 +277,9 @@ class PipelineParallel(MetaParallelBase):
...
@@ -274,6 +277,9 @@ class PipelineParallel(MetaParallelBase):
input_buffers
.
append
(
input_tensor
)
input_buffers
.
append
(
input_tensor
)
output_buffers
.
append
(
output_tensor
)
output_buffers
.
append
(
output_tensor
)
if
not
self
.
is_pipeline_last_stage
():
self
.
_release_output
(
output_tensor
)
input_tensor
,
output_tensor
=
input_buffers
.
pop
(
input_tensor
,
output_tensor
=
input_buffers
.
pop
(
0
0
),
output_buffers
.
pop
(
0
)
),
output_buffers
.
pop
(
0
)
...
@@ -608,6 +614,14 @@ class PipelineParallel(MetaParallelBase):
...
@@ -608,6 +614,14 @@ class PipelineParallel(MetaParallelBase):
if
self
.
lr_scheduler
:
if
self
.
lr_scheduler
:
self
.
lr_scheduler
.
step
()
self
.
lr_scheduler
.
step
()
def
_release_output
(
self
,
output
):
if
isinstance
(
output
,
(
tuple
,
list
)):
for
t
in
output
:
if
t
is
not
None
and
isinstance
(
t
,
paddle
.
Tensor
):
t
.
_clear_dataptr
()
elif
output
is
not
None
and
isinstance
(
output
,
paddle
.
Tensor
):
output
.
_clear_dataptr
()
class
PipelineParallelWithInterleave
(
PipelineParallel
):
class
PipelineParallelWithInterleave
(
PipelineParallel
):
# pipeline parallel with interleave scheduler
# pipeline parallel with interleave scheduler
...
@@ -782,6 +796,8 @@ class PipelineParallelWithInterleave(PipelineParallel):
...
@@ -782,6 +796,8 @@ class PipelineParallelWithInterleave(PipelineParallel):
# append input_tensor no matter none or not
# append input_tensor no matter none or not
self
.
input_tensors
[
next_virtual_pp_rank
].
append
(
input_tensor
)
self
.
input_tensors
[
next_virtual_pp_rank
].
append
(
input_tensor
)
self
.
_release_output
(
output_tensor
)
# run 1f1b steady steps
# run 1f1b steady steps
for
micro_step
in
range
(
steady_steps
):
for
micro_step
in
range
(
steady_steps
):
# forward
# forward
...
@@ -859,6 +875,9 @@ class PipelineParallelWithInterleave(PipelineParallel):
...
@@ -859,6 +875,9 @@ class PipelineParallelWithInterleave(PipelineParallel):
self
.
output_tensor_grads
[
next_backward_virtual_pp_rank
].
append
(
self
.
output_tensor_grads
[
next_backward_virtual_pp_rank
].
append
(
output_tensor_grad
output_tensor_grad
)
)
self
.
_release_output
(
output_tensor
)
self
.
_release_output
(
output_tensor
)
# remaining backward steps
# remaining backward steps
if
not
forward_only
:
if
not
forward_only
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录