Commit 7df043ec (unverified)
Authored Jun 13, 2023 by zhenhailiu; committed via GitHub on Jun 13, 2023
pipeline model: remove self.data (#54387)

* polish * polish * polish * polish * polish * polish

Parent: 161dad50
Changes: 1 changed file with 128 additions and 68 deletions (+128, -68)

python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py (+128, -68) @ 7df043ec
...
@@ -30,6 +30,85 @@ from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer, assign_group_by_size

 __all__ = []

+
+# assume only the first stage and last stage need data, and data consumption are ordred;
+# to be replaced by real micro dataset from reader
+class FakeMicroDataset:
+    def __init__(
+        self, data, is_first_stage, is_last_stage, acc_steps, micro_batch_size
+    ):
+        self._data = data
+        self._index = 0
+        self._acc_steps = acc_steps
+        self._is_first_stage = is_first_stage
+        self._is_last_stage = is_last_stage
+        self._micro_batch_size = micro_batch_size
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        assert self._index < self._acc_steps
+        assert self._is_first_stage or self._is_last_stage
+        micro_batch_data = self._load_micro_batch(self._index)
+        self._index += 1
+        return micro_batch_data
+
+    def _load_micro_batch(self, micro_step):
+        inputs = self._data
+
+        if self._is_first_stage or self._is_last_stage:
+            assert len(inputs) == 2, "length of input should be 2"
+            data = self._load_micro_batch_impl(inputs[0], micro_step)
+            label = self._load_micro_batch_impl(inputs[1], micro_step)
+            return (data, label)
+        else:
+            return (None, None)
+
+    def _load_micro_batch_impl(self, inputs, micro_step):
+        begin = micro_step * self._micro_batch_size
+        end = begin + self._micro_batch_size
+
+        if isinstance(inputs, tuple):
+            output = []
+            for data in inputs:
+                if isinstance(data, list):
+                    assert (
+                        len(data) == self._acc_steps
+                    ), "length of data should be %d, but it is %d" % (
+                        self._acc_steps,
+                        len(data),
+                    )
+                    output.append(data[micro_step].detach())
+                elif data is not None:
+                    self._check_data_vaild(data)
+                    output.append(data[begin:end, :].detach())
+                else:
+                    output.append(None)
+            return tuple(output)
+
+        elif isinstance(inputs, list):
+            assert (
+                len(inputs) == self._acc_steps
+            ), "length of data should be %d, but it is %d" % (
+                self.accumulate_steps,
+                len(inputs),
+            )
+            return inputs[micro_step].detach()
+        elif inputs is not None:
+            self._check_data_vaild(inputs)
+            return inputs[begin:end, :].detach()
+        else:
+            return None
+
+    def _check_data_vaild(self, data):
+        batch_size = data.shape[0]
+        assert self._micro_batch_size * self._acc_steps == batch_size, (
+            "batch_size needs to be divisible by micro_batch_size. Currently, "
+            "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
+            % (batch_size, self._micro_batch_size, self._acc_steps)
+        )
+
+
 class PipelineParallel(MetaParallelBase):
     def __init__(self, layers, hcg, strategy):
         if not isinstance(layers, PipelineLayer):
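The new FakeMicroDataset above replaces the old self.data bookkeeping: it holds the whole global batch and hands out one (data, label) micro-batch per __next__ call, slicing rows [micro_step * micro_batch_size, (micro_step + 1) * micro_batch_size) out of each tensor. A minimal, dependency-free sketch of that slicing arithmetic (plain Python lists stand in for tensors; the names are illustrative only, not part of the Paddle API):

# Hypothetical sketch of the slicing done by _load_micro_batch_impl:
# one contiguous row range per gradient-accumulation step.
acc_steps = 4          # gradient accumulation steps
micro_batch_size = 2   # rows handed to the model per step
batch = list(range(acc_steps * micro_batch_size))   # stand-in for an 8-row tensor

for micro_step in range(acc_steps):
    begin = micro_step * micro_batch_size
    end = begin + micro_batch_size
    print(micro_step, batch[begin:end])
# 0 [0, 1]
# 1 [2, 3]
# 2 [4, 5]
# 3 [6, 7]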
...
@@ -237,9 +316,6 @@ class PipelineParallel(MetaParallelBase):
         self.scaler = scaler

-        # store data for train
-        self.data = data
-
         # store total loss of entire batch
         self.total_loss = None
...
@@ -253,10 +329,12 @@ class PipelineParallel(MetaParallelBase):
         input_buffers = []
         output_buffers = []

+        micro_dataset = self._wrap_data(data)
+
         for step_id in range(startup_steps):
             input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

-            output_tensor = self._forward_step(input_tensor)
+            output_tensor = self._forward_step(input_tensor, micro_dataset)
             p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

             input_buffers.append(input_tensor)
...
@@ -271,7 +349,7 @@ class PipelineParallel(MetaParallelBase):
         for i in range(steady_steps):
             last_iter = i == (steady_steps - 1)

-            output_tensor = self._forward_step(input_tensor)
+            output_tensor = self._forward_step(input_tensor, micro_dataset)

             output_tensor_grad = p2p.send_forward_recv_backward(
                 output_tensor, self.is_pipeline_last_stage()
...
@@ -365,6 +443,22 @@ class PipelineParallel(MetaParallelBase):

         return data

+    def _wrap_data(self, data):
+        """
+        for backward compatibilty, wrap data to Fake FakeMicroDataset if it is of type list or tuple
+        """
+        if (not isinstance(data, tuple)) and (not isinstance(data, list)):
+            return data
+
+        micro_dataset = FakeMicroDataset(
+            data,
+            self.is_pipeline_first_stage(ignore_virtual=True),
+            self.is_pipeline_last_stage(ignore_virtual=True),
+            self.accumulate_steps,
+            self.micro_batch_size,
+        )
+        return micro_dataset
+
     def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
         data = self._prepare_training(data, optimizer, lr_scheduler)
         # 1f1b scheduler for pipeline parallel
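_wrap_data keeps the existing train_batch/eval_batch contract: a list or tuple covering the whole global batch is wrapped into a FakeMicroDataset, while anything else is passed through untouched. A small standalone sketch of that dispatch (names are illustrative only, not the Paddle API):

# Hypothetical sketch of the _wrap_data dispatch: legacy (inputs, labels)
# batches get wrapped, everything else is assumed to already yield micro-batches.
def wrap_data_sketch(data, make_micro_dataset):
    if not isinstance(data, (tuple, list)):
        return data                      # e.g. an existing iterator: pass through
    return make_micro_dataset(data)      # legacy whole-batch tuple: wrap it

legacy_batch = ([10, 20, 30, 40], [1, 2, 3, 4])            # (inputs, labels)
wrapped = wrap_data_sketch(legacy_batch, make_micro_dataset=lambda d: iter(zip(*d)))
print(next(wrapped))                                       # (10, 1)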
...
@@ -383,8 +477,6 @@ class PipelineParallel(MetaParallelBase):
         self._layers.eval()
         self._compute_loss = compute_loss

-        # save data for eval
-        self.data = data
-
         # store data id for micro_batch
         self.micro_batch_id = 0
...
@@ -398,10 +490,12 @@ class PipelineParallel(MetaParallelBase):
         input_buffers = []
         output_buffers = []

+        micro_dataset = self._wrap_data(data)
+
         for step_id in range(startup_steps):
             input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

-            output_tensor = self._forward_step(input_tensor)
+            output_tensor = self._forward_step(input_tensor, micro_dataset)
             p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

             input_buffers.append(input_tensor)
...
@@ -413,7 +507,7 @@ class PipelineParallel(MetaParallelBase):
         for i in range(steady_steps):
             last_iter = i == (steady_steps - 1)

-            output_tensor = self._forward_step(input_tensor)
+            output_tensor = self._forward_step(input_tensor, micro_dataset)
             p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

             input_buffers.append(input_tensor)
...
@@ -429,11 +523,12 @@ class PipelineParallel(MetaParallelBase):

         return self.train_loss

-    def _forward_step(self, input_tensor, chunk_id=None):
+    def _forward_step(self, input_tensor, micro_dataset, chunk_id=None):
         if self._enable_timer:
             self.timers("forward_step").start()
         if self.is_pipeline_first_stage():
-            input_tensor = self._load_micro_batch(self.micro_batch_id)
+            input_tensor = next(micro_dataset)[0]
+            self._check_micro_batch_data_valid(input_tensor)

         assert chunk_id is None or isinstance(chunk_id, int)
...
@@ -445,7 +540,8 @@ class PipelineParallel(MetaParallelBase):
                 assert (
                     self._layers._loss_fn is not None
                 ), "loss function should exist to compute loss"
-                labels = self._load_micro_batch(self.micro_batch_id)
+                labels = next(micro_dataset)[1]
+                self._check_micro_batch_data_valid(labels)
                 output_tensor = self._layers._loss_fn(output_tensor, labels)
                 assert isinstance(
                     output_tensor, (paddle.Tensor, framework.core.eager.Tensor)
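With these two hunks, _forward_step no longer indexes a stored self.data by micro_batch_id; it advances the shared iterator and takes element [0] (model input) on the first stage and element [1] (labels for the loss) on the last stage. A standalone sketch of that per-stage consumption (illustrative only; real middle stages receive activations over p2p and never touch the dataset):

# Hypothetical sketch: each micro-batch is a (data, label) pair and each stage
# only reads the part it needs.
micro_batches = iter([(("x0",), ("y0",)), (("x1",), ("y1",))])

is_first_stage, is_last_stage = True, False     # pretend we are stage 0 of 4
batch = next(micro_batches)
inputs = batch[0] if is_first_stage else None   # feeds the first pipeline chunk
labels = batch[1] if is_last_stage else None    # feeds the loss on the last stage
print(inputs, labels)                           # ('x0',) None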
...
@@ -467,6 +563,16 @@ class PipelineParallel(MetaParallelBase):
             self.timers("forward_step").stop()
         return output_tensor

+    def _check_micro_batch_data_valid(self, micro_batch_data):
+        if isinstance(micro_batch_data, (tuple, list)):
+            for data in micro_batch_data:
+                self._check_micro_batch_data_valid(data)
+        elif micro_batch_data is not None:
+            micro_batch_size = micro_batch_data.shape[0]
+            assert (
+                micro_batch_size == self.micro_batch_size
+            ), f"expected micro_batch_size {self.micro_batch_size} but get {micro_batch_size}"
+
     def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
         if self._enable_timer:
             self.timers("backward_step").start()
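The new _check_micro_batch_data_valid walks nested tuples/lists and asserts that every tensor it finds has exactly micro_batch_size rows, replacing the old whole-batch _check_data_vaild. A standalone sketch of the same recursive check (a tiny stand-in class replaces a real tensor; names are illustrative only):

# Hypothetical sketch of the recursive per-micro-batch shape check.
def check_micro_batch(data, micro_batch_size):
    if isinstance(data, (tuple, list)):
        for item in data:
            check_micro_batch(item, micro_batch_size)
    elif data is not None:
        assert data.shape[0] == micro_batch_size, (
            f"expected micro_batch_size {micro_batch_size} "
            f"but get {data.shape[0]}"
        )

class FakeTensor:
    def __init__(self, rows):
        self.shape = (rows, 1)   # only shape[0] matters for the check

check_micro_batch((FakeTensor(2), [FakeTensor(2), None]), micro_batch_size=2)  # passes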
...
@@ -503,57 +609,6 @@ class PipelineParallel(MetaParallelBase):
             self.timers("backward_step").stop()
         return input_tensor_grad

-    def _check_data_vaild(self, data):
-        batch_size = data.shape[0]
-        assert self.micro_batch_size * self.accumulate_steps == batch_size, (
-            "batch_size needs to be divisible by micro_batch_size. Currently, "
-            "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
-            % (batch_size, self.micro_batch_size, self.accumulate_steps)
-        )
-
-    def _load_micro_batch_impl(self, inputs, cache_id):
-        begin = cache_id * self.micro_batch_size
-        end = begin + self.micro_batch_size
-
-        if isinstance(inputs, tuple):
-            output = []
-            for data in inputs:
-                if isinstance(data, list):
-                    assert (
-                        len(data) == self.accumulate_steps
-                    ), "length of data should be %d, but it is %d" % (
-                        self.accumulate_steps,
-                        len(data),
-                    )
-                    output.append(data[cache_id].detach())
-                else:
-                    self._check_data_vaild(data)
-                    output.append(data[begin:end, :].detach())
-            return tuple(output)
-
-        elif isinstance(inputs, list):
-            assert (
-                len(inputs) == self.accumulate_steps
-            ), "length of data should be %d, but it is %d" % (
-                self.accumulate_steps,
-                len(inputs),
-            )
-            return inputs[cache_id].detach()
-        else:
-            self._check_data_vaild(inputs)
-            return inputs[begin:end, :].detach()
-
-    def _load_micro_batch(self, cache_id):
-        inputs = self.data
-        if self.is_pipeline_first_stage():
-            assert len(inputs) == 2, "length of input should be 2"
-            return self._load_micro_batch_impl(inputs[0], cache_id)
-        elif self.is_pipeline_last_stage():
-            assert len(inputs) == 2, "length of input should be 2"
-            return self._load_micro_batch_impl(inputs[1], cache_id)
-        else:
-            inputs = None
-
     def _broadcast_final_loss(self):
         # Since the last backward run in interleave will set the virtual rank to 0,
         # here we need to check last stage ignoring virtual stage.
...
@@ -658,7 +713,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
             virtual_pp_stage = self.num_model_chunks - virtual_pp_stage - 1
         return virtual_pp_stage

-    def _forward_step_helper(self, micro_step):
+    def _forward_step_helper(self, micro_dataset, micro_step):
         virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=True)
         self.set_virtual_pipeline_rank(virtual_pp_rank)
...
@@ -674,7 +729,9 @@ class PipelineParallelWithInterleave(PipelineParallel):
         ):
             self.input_tensors[virtual_pp_rank].append(None)

         input_tensor = self.input_tensors[virtual_pp_rank][-1]
-        output_tensor = self._forward_step(input_tensor, virtual_pp_rank)
+        output_tensor = self._forward_step(
+            input_tensor, micro_dataset, virtual_pp_rank
+        )
         self.output_tensors[virtual_pp_rank].append(output_tensor)

         if self._forward_only:
...
@@ -719,7 +776,6 @@ class PipelineParallelWithInterleave(PipelineParallel):
         # init some attributes for this batch run
         self.scaler = scaler
-        self.data = data
         self.total_loss = None
         self.micro_batch_id = 0
         self._forward_only = forward_only
...
@@ -729,6 +785,8 @@ class PipelineParallelWithInterleave(PipelineParallel):
         self.output_tensors = [[] for _ in range(self.num_model_chunks)]
         self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)]

+        micro_dataset = self._wrap_data(data)
+
         num_steps = self.accumulate_steps * self.num_model_chunks
         all_startup_steps = False
         if forward_only:
...
@@ -752,7 +810,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
         # run startup steps
         for micro_step in range(startup_steps):
-            output_tensor = self._forward_step_helper(micro_step)
+            output_tensor = self._forward_step_helper(micro_dataset, micro_step)

             # determine whether recv forward tensor or not
             next_virtual_pp_rank = self._get_virtual_pp_rank(
...
@@ -806,7 +864,9 @@ class PipelineParallelWithInterleave(PipelineParallel):
         for micro_step in range(steady_steps):
             # forward
             forward_micro_step_id = micro_step + startup_steps
-            output_tensor = self._forward_step_helper(forward_micro_step_id)
+            output_tensor = self._forward_step_helper(
+                micro_dataset, forward_micro_step_id
+            )

             # backward
             backward_micro_step_id = micro_step
...
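Taken together, each scheduler in this file (plain 1F1B training, forward-only evaluation, and the interleaved variant) now builds one micro-batch iterator per batch via _wrap_data and threads it through every forward step, instead of each step re-slicing a stored self.data. A standalone sketch of that calling convention (illustrative only; the real schedulers also exchange activations over p2p between steps, and only the first and last stages actually read the data):

# Hypothetical sketch: a single shared iterator is created once per batch and
# consumed step by step during the forward (micro) steps.
def run_forward_steps(micro_dataset, acc_steps, forward_step):
    outputs = []
    for _ in range(acc_steps):
        outputs.append(forward_step(next(micro_dataset)))
    return outputs

micro_dataset = iter([({"x": i}, {"y": i}) for i in range(4)])
print(run_forward_steps(micro_dataset, 4, forward_step=lambda mb: mb[0]["x"]))
# [0, 1, 2, 3]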