Unverified commit fa878846
Authored by Yuang Liu on Aug 09, 2023; committed via GitHub on Aug 09, 2023.

cherry pick #55651 and #55890 (#56063)

Parent: 0d920178
Showing 2 changed files with 90 additions and 19 deletions (+90 −19):

python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py   +79 −13
python/paddle/distributed/fleet/utils/tensor_fusion_helper.py        +11 −6
python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -11,8 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 import os
+import sys
 import time
 import warnings
+from collections import defaultdict
 
 import paddle
 from paddle import framework
@@ -217,7 +219,7 @@ class PipelineParallel(MetaParallelBase):
             self._dp_comm_overlap and self._sharding_comm_overlap
         ), "Cannot use dp pp overlap and sharding pp overlap at the same time."
 
-        self._comm_buffers = []
+        self._chunk_2_comm_buffers = defaultdict(list)
         self._comm_overlap = (
             self._dp_comm_overlap or self._sharding_comm_overlap
         )
@@ -291,7 +293,9 @@ class PipelineParallel(MetaParallelBase):
 
         return fused_allreduce
 
-    def register_allreduce_overlap_hook(self, model, comm_group, acc_steps, dp):
+    def register_allreduce_overlap_hook(
+        self, model, comm_group, acc_steps, dp, group_size=128 * 1024 * 1024
+    ):
         if model.get_num_virtual_stages() > 1:
             models = model.get_model_chunks()
         else:
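The hunk above adds a group_size parameter (default 128 * 1024 * 1024 bytes) that is forwarded to assign_group_by_size further down, capping how many parameter bytes get fused into a single FusedCommBuffer. As a rough illustration of size-based grouping (a minimal sketch only, not Paddle's actual assign_group_by_size; the tensor names and sizes are made up):

# Illustrative sketch only: groups parameters by cumulative byte size, in the
# spirit of assign_group_by_size(parameter_list, group_size). Paddle's real
# helper differs in details (dtype handling, ordering, special cases).
def group_by_size_sketch(params, group_size=128 * 1024 * 1024):
    # params: list of (name, nbytes) pairs; insertion order is preserved.
    groups = {}
    group_idx, cur_bytes = 0, 0
    for name, nbytes in params:
        if cur_bytes > 0 and cur_bytes + nbytes > group_size:
            group_idx += 1
            cur_bytes = 0
        groups.setdefault(group_idx, []).append(name)
        cur_bytes += nbytes
    return groups


MiB = 1024 * 1024
print(group_by_size_sketch([("w0", 100 * MiB), ("w1", 100 * MiB), ("w2", 20 * MiB)]))
# {0: ['w0'], 1: ['w1', 'w2']} -> each group later becomes one FusedCommBuffer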
@@ -308,7 +312,7 @@ class PipelineParallel(MetaParallelBase):
             else HOOK_ACTION.REDUCE
         )
 
-        for model in models:
+        for chunk_idx, model in enumerate(models):
             # For virtual pipeline. Will separate parameters in different chunk into
             # different groups to get the best performance.
@@ -338,12 +342,12 @@ class PipelineParallel(MetaParallelBase):
                 dst = comm_group.ranks[dst]
             else:
                 dst = -1
-            var_groups = assign_group_by_size(parameter_list)
+            var_groups = assign_group_by_size(parameter_list, group_size)
             for group_idx, parameters in var_groups.items():
                 buffer = FusedCommBuffer(
                     group_idx, parameters, comm_group, acc_steps, act, dst
                 )
-                self._comm_buffers.append(buffer)
+                self._chunk_2_comm_buffers[chunk_idx].append(buffer)
                 for param in parameters:
                     param._register_backward_hook(
                         self.bw_hook_func(buffer, param)
@@ -514,9 +518,12 @@ class PipelineParallel(MetaParallelBase):
             self._flush_records()
 
         if self._comm_overlap:
-            assert len(self._comm_buffers) > 0
-            for buffer in self._comm_buffers:
-                buffer.scale_grads()
+            assert (
+                len(self._chunk_2_comm_buffers) > 0
+            ), "comm buffers should be created"
+            for _, buffers in self._chunk_2_comm_buffers.items():
+                for buffer in buffers:
+                    buffer.scale_grads()
 
         if self._enable_timer:
             self.timers("allreduce_shared_weight_gradients").start()
@@ -557,7 +564,7 @@ class PipelineParallel(MetaParallelBase):
 
         self._layers.train()
 
-        if self._sharding_comm_overlap and len(self._comm_buffers) == 0:
+        if self._sharding_comm_overlap and len(self._chunk_2_comm_buffers) == 0:
             self.register_allreduce_overlap_hook(
                 self._layers, self.sharding_group, self.accumulate_steps, False
             )
@@ -932,6 +939,39 @@ class PipelineParallelWithInterleave(PipelineParallel):
 
         return output_tensor
 
+    def _overlap_comm_grads(self):
+        if self._comm_overlap:
+            self._backward_step_count += 1
+            sync_step = self._backward_step_count - self.stage_id
+            if sync_step > 0 and sync_step % self.num_stages == 0:
+                chunk_idx = self._virtual_pp_world_size - (
+                    sync_step // self.num_stages
+                )
+                for buffer in self._chunk_2_comm_buffers[chunk_idx]:
+                    buffer.comm_grads()
+
+            if self.stage_id != 0:
+                if (
+                    self._backward_step_count
+                    == self.num_stages * self.num_model_chunks
+                ):
+                    for buffer in self._chunk_2_comm_buffers[0]:
+                        buffer.comm_grads()
+
+    def _sync_overlap_grads(self):
+        if self._comm_overlap:
+            assert (
+                self._backward_step_count
+                == self.num_stages * self.num_model_chunks
+            ), (
+                "backward step count should be equal to accumulate steps * virtual pp world size,"
+                f" but get {self._backward_step_count}, excepted result is {self.num_stages * self.num_model_chunks}"
+            )
+
+            for _, buffers in self._chunk_2_comm_buffers.items():
+                for buffer in buffers:
+                    buffer.scale_grads()
+
     def _backward_step_helper(self, micro_step):
         virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=False)
         self.set_virtual_pipeline_rank(virtual_pp_rank)
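The new _overlap_comm_grads launches each chunk's gradient collectives as soon as that chunk's backward micro-steps finish on this stage: whenever sync_step = _backward_step_count - stage_id reaches a positive multiple of num_stages, the buffers of chunk _virtual_pp_world_size - sync_step // num_stages fire, so higher-numbered chunks (which complete backward first in the interleaved schedule) communicate first and chunk 0 goes last; non-zero stages flush chunk 0 explicitly once the counter reaches num_stages * num_model_chunks. The toy loop below (a standalone sketch, not Paddle code, assuming _backward_step_count starts at 0 and that _virtual_pp_world_size equals num_model_chunks) replays that counter arithmetic to show the firing order:

# Standalone toy replay of the counter arithmetic in _overlap_comm_grads
# (illustration only). Assumes no extra accumulation rounds, i.e. the counter
# starts at 0; with gradient accumulation it starts negative (see below).
def firing_order(num_stages, num_model_chunks, stage_id):
    backward_step_count = 0
    fired = []  # chunk indices whose buffers would call comm_grads(), in order
    total_steps = num_stages * num_model_chunks
    for _ in range(total_steps):
        backward_step_count += 1
        sync_step = backward_step_count - stage_id
        if sync_step > 0 and sync_step % num_stages == 0:
            fired.append(num_model_chunks - sync_step // num_stages)
        if stage_id != 0 and backward_step_count == total_steps:
            fired.append(0)
    return fired


# 4 pipeline stages, 2 virtual chunks per stage: the last chunk fires first.
print(firing_order(num_stages=4, num_model_chunks=2, stage_id=0))  # [1, 0]
print(firing_order(num_stages=4, num_model_chunks=2, stage_id=3))  # [1, 0]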
@@ -955,8 +995,24 @@ class PipelineParallelWithInterleave(PipelineParallel):
             input_tensor, output_tensor, output_tensor_grad
         )
 
+        self._overlap_comm_grads()
+
         return input_tensor_grad
 
+    def bw_hook_func(self, buffer, param):
+        # For pipeline with interleave, we need to add grad to buffer without communication.
+        # Use communication where appropriate to avoid dp communication and pp scheduling conflicts.
+        @paddle.autograd.no_grad()
+        def fused_allreduce(*_):
+            buffer.add_grad(param, use_comm=False)
+
+        return fused_allreduce
+
+    def register_allreduce_overlap_hook(self, model, comm_group, acc_steps, dp):
+        super().register_allreduce_overlap_hook(
+            model, comm_group, acc_steps, dp, group_size=sys.maxsize
+        )
+
     def forward_backward_pipeline(
         self,
         data,
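In the interleaved subclass, bw_hook_func is overridden so the backward hook only accumulates into the buffer (use_comm=False), and register_allreduce_overlap_hook passes group_size=sys.maxsize so each parameter group stays in one large FusedCommBuffer per chunk; the collectives are then launched by _overlap_comm_grads and drained by _sync_overlap_grads. The toy stand-in class below illustrates this accumulate-now, communicate-later pattern (it is not Paddle's FusedCommBuffer; names and behavior are heavily simplified):

# Toy stand-in showing the deferred-communication pattern (illustration only).
class ToyCommBuffer:
    def __init__(self, params):
        self._params = set(params)
        self._checked_in = set()
        self._task = None

    def add_grad(self, param, use_comm=True):
        self._checked_in.add(param)
        if use_comm and self._checked_in == self._params:
            self.comm_grads()  # eager path used by dp/sharding overlap

    def comm_grads(self):
        assert self._checked_in == self._params, "Not all params checked in."
        self._task = "allreduce(%s)" % sorted(self._params)  # pretend async launch

    def scale_grads(self):
        assert self._task is not None, "Task is not initialized."
        print("wait on", self._task, "then scale by 1/nranks")


buf = ToyCommBuffer({"w", "b"})
buf.add_grad("w", use_comm=False)  # backward hook in the interleaved schedule
buf.add_grad("b", use_comm=False)
buf.comm_grads()                   # fired later by _overlap_comm_grads
buf.scale_grads()                  # drained by _sync_overlap_grads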
@@ -995,6 +1051,19 @@ class PipelineParallelWithInterleave(PipelineParallel):
         self.micro_batch_id = 0
         self._forward_only = forward_only
 
+        # store the number of backward steps
+        assert (
+            self.accumulate_steps % self.num_stages == 0
+        ), "accumulate_steps({}) should be evenly divisible by num_stages({}) for pipeline with interleave".format(
+            self.accumulate_steps, self.num_stages
+        )
+        per_stage_accumulate_steps = self.accumulate_steps // self.num_stages
+        self._backward_step_count = (
+            -(per_stage_accumulate_steps - 1)
+            * self.num_stages
+            * self.num_model_chunks
+        )
+
         # init some data buffers for interleave scheduler
         self.input_tensors = [[] for _ in range(self.num_model_chunks)]
         self.output_tensors = [[] for _ in range(self.num_model_chunks)]
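_backward_step_count starts negative so that, after all accumulate_steps * num_model_chunks backward micro-steps of a stage, it lands exactly on num_stages * num_model_chunks, the value asserted in _sync_overlap_grads; as a consequence, sync_step in _overlap_comm_grads only turns positive during the final accumulation round, which is when communication overlap should kick in. A quick arithmetic check with illustrative values:

# Quick arithmetic check of the counter initialization (illustrative values).
accumulate_steps, num_stages, num_model_chunks = 8, 4, 2
assert accumulate_steps % num_stages == 0

per_stage_accumulate_steps = accumulate_steps // num_stages  # 2
start = -(per_stage_accumulate_steps - 1) * num_stages * num_model_chunks  # -8
total_backward_steps = accumulate_steps * num_model_chunks  # 16 per stage

# Each backward micro-step increments the counter once, so it ends on the
# value asserted in _sync_overlap_grads: num_stages * num_model_chunks.
assert start + total_backward_steps == num_stages * num_model_chunks  # -8 + 16 == 8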
@@ -1254,10 +1323,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
                 )
             )
 
-        if self._comm_overlap:
-            assert len(self._comm_buffers) > 0
-            for buffer in self._comm_buffers:
-                buffer.scale_grads()
+        self._sync_overlap_grads()
 
         if static_scheduler:
             self._reset_counter()
python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
@@ -206,7 +206,7 @@ class FusedCommBuffer:
             and len(self._params_step_dict) == 0
         )
 
-    def add_grad(self, param):
+    def add_grad(self, param, use_comm=True):
         assert param.name in self._params_step_dict
         current_ptr = (
             param.main_grad.data_ptr()
@@ -227,12 +227,17 @@ class FusedCommBuffer:
         self._params_checked_in += 1
         self._params_step_dict.pop(param.name)
 
-        if self._all_params_checked_in:
-            self._comm_grads()
+        if self._all_params_checked_in and use_comm:
+            self.comm_grads()
 
     @imperative_base.no_grad
-    def _comm_grads(self):
-        assert self._all_params_checked_in
+    def comm_grads(self):
+        assert self._all_params_checked_in, (
+            "Not all params checked in."
+            "Parameter number: {}, Check-in number: {}".format(
+                len(self._params), self._params_checked_in
+            )
+        )
 
         if not self._scale_after_comm:
             scale_factor = 1.0 / self._comm_group.nranks
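When _scale_after_comm is false, each rank pre-scales its gradients by 1.0 / nranks before the collective, so an allreduce-sum directly produces the average. A minimal numeric check (plain Python, illustrative values only):

# Scaling each rank's gradient by 1/nranks before an allreduce-sum
# gives the same result as averaging after the sum.
nranks = 8
local_grads = [float(r + 1) for r in range(nranks)]  # hypothetical per-rank grads

pre_scaled = sum(g * (1.0 / nranks) for g in local_grads)
post_scaled = sum(local_grads) / nranks

assert abs(pre_scaled - post_scaled) < 1e-12
print(pre_scaled)  # 4.5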
@@ -255,7 +260,7 @@ class FusedCommBuffer:
 
     @imperative_base.no_grad
     def scale_grads(self):
-        assert self._task is not None
+        assert self._task is not None, "Task is not initialized."
         self._task.wait()
 
         if self._scale_after_comm: