机器未来 / Paddle (forked from PaddlePaddle / Paddle)

Commit d8b4ca92 (unverified)
Authored Oct 09, 2022 by Yuang Liu; committed via GitHub on Oct 09, 2022
Parent: 9a849a37

[dygraph sharding stage 2] sharding broadcast overlap (#46656)

3 changed files, +102 additions, -20 deletions:

python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py  (+88, -10)
python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py  (+9, -9)
python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage2_comm_overlap.py  (+5, -1)
python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py

@@ -24,6 +24,8 @@
 import copy
 import logging
+import warnings
 import numpy as np
 from collections import OrderedDict

@@ -87,7 +89,7 @@ class GroupShardedOptimizerStage2(Optimizer):
         self._optim = optim
         # sharing stage 2 comm overlap flag
-        self._comm_overlap = False
+        self._reduce_overlap = False
         # record the last task used for comm overlap for sharding stage 2
         self._comm_task = None

@@ -108,6 +110,17 @@ class GroupShardedOptimizerStage2(Optimizer):
                 filter(lambda x: x.trainable and x.dtype == Type.fp16.value,
                        self._local_params))) > 0

+        self._broadcast_overlap = False
+        self._forward_pre_hook_remove_helper = []
+        try:
+            # The fp32 params such as layer_norm_0.w_0 will be at the end of param_list.
+            # Have to sort the params to make sure all params are in the forward using order.
+            self._broadcast_order_params = sorted(
+                self.local_params,
+                key=lambda x: int(x.name.split('.')[0].split('_')[-1]))
+        except ValueError:
+            self._broadcast_order_params = None
+
         self._group = new_group(
             _get_global_group().ranks) if group is None else group
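The `sorted` call above keys each parameter by the numeric suffix that Paddle appends to auto-generated parameter names, so that broadcasts can later be issued in forward-execution order; a name that does not follow the pattern raises `ValueError` and the ordering is disabled. A small standalone illustration of that key, with made-up names:

# Standalone illustration of the ordering key; the names below are invented but
# follow Paddle's usual "<op>_<index>.<suffix>" auto-naming scheme.
names = ["linear_10.w_0", "linear_2.b_0", "layer_norm_0.w_0", "linear_2.w_0"]

def creation_index(name):
    # "linear_10.w_0" -> "linear_10" -> "10" -> 10
    return int(name.split('.')[0].split('_')[-1])

print(sorted(names, key=creation_index))
# ['layer_norm_0.w_0', 'linear_2.b_0', 'linear_2.w_0', 'linear_10.w_0']

# A custom name such as "my_weight" would make int() raise ValueError, which is
# exactly the case the except branch above handles by falling back to None.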
@@ -163,15 +176,34 @@ class GroupShardedOptimizerStage2(Optimizer):
                               sync_op=True)

     def _update_task(self, task):
-        if self._comm_overlap:
+        if self._reduce_overlap:
             assert task is not None
             # Only track of the last reduce task.
             # Since all tasks are on the same stream, only need to wait the last one.
             # After waiting for the last reduce task, all reduce tasks before have already finished.
             self._comm_task = task

-    def _set_comm_overlap(self, comm_overlap):
-        self._comm_overlap = comm_overlap
+    def _set_reduce_overlap(self, reduce_overlap):
+        # Enable gradients' reduces overlap with backward calculation.
+        self._reduce_overlap = reduce_overlap
+
+    def _set_broadcast_overlap(self, broadcast_overlap, layers=None):
+        # Enable post optimizer broadcasts overlap with the forward calculation of next batch.
+        self._broadcast_overlap = broadcast_overlap
+        if self._broadcast_overlap:
+            assert layers is not None, \
+                "To enable broadcast overlap forward, please pass the module to the function."
+            self._layers = layers
+            warnings.warn(
+                "Setting overlap broadcast means the `paddle.device.cuda.synchronize()` "
+                "must be called manually before calling `paddle.save()` and before and inference."
+            )
+            if self._broadcast_order_params is None:
+                # Params' names should be like column_linear_32.w_0 patter to get the best performance.
+                warnings.warn(
+                    "The param name passed to the optimizer doesn't follow .+_[0-9]+\..+ patter, "
+                    "overlap broadcast may harm the performance.")
+                self._broadcast_order_params = self._local_params

     def _generate_master_params(self, trainable_params):
         if self.offload:

@@ -382,6 +414,12 @@ class GroupShardedOptimizerStage2(Optimizer):
         """
         # This method won't be called directly by opt.step()!
         # The _redefine_opt_step() in class GroupShardedStage2 will wrap this function.
+        if self._broadcast_overlap:
+            # Clear the pre forward hook in the optimizer step.
+            for hook_remove in self._forward_pre_hook_remove_helper:
+                hook_remove.remove()
+            self._forward_pre_hook_remove_helper = []
+
         if self.offload:
             params_list = [self.offload_params.buffer]

@@ -425,9 +463,49 @@ class GroupShardedOptimizerStage2(Optimizer):
         """Broadcast the parameters of the current rank to each rank"""
         # Exchange all the shards with the other ranks
-        for dtype_per_rank in self.param_storages.values():
-            for dst_rank, internal_storage in dtype_per_rank.items():
-                broadcast(tensor=internal_storage.buffer,
-                          src=self._group.ranks[dst_rank],
-                          group=self._group,
-                          sync_op=True)
+        if self._broadcast_overlap:
+            self._broadcast_params_overlap_forward()
+        else:
+            for dtype_per_rank in self.param_storages.values():
+                for dst_rank, internal_storage in dtype_per_rank.items():
+                    broadcast(tensor=internal_storage.buffer,
+                              src=self._group.ranks[dst_rank],
+                              group=self._group,
+                              sync_op=True)
+
+    def _forward_pre_hook_function(self, tasks):
+        # Since the layers will call pre hook by `forward_pre_hook(self, inputs)`,
+        # the helper functions needs the x and y to take those params.
+        def __impl__(x, y):
+            for task in tasks:
+                # Wait for broadcast task before using the result of the broadcast.
+                task.wait()
+
+        return __impl__
+
+    @paddle.autograd.no_grad()
+    def _broadcast_params_overlap_forward(self):
+        # Exchange all the shards with the other ranks,
+        # but overlap the broadcast with next batch's calculation.
+        param2task = {}
+        for x in self._broadcast_order_params:
+            if x.trainable:
+                task = broadcast(tensor=x,
+                                 src=self._group.ranks[self._param2rank[x.name]],
+                                 group=self._group,
+                                 sync_op=False)
+                assert x.name not in param2task
+                param2task[x.name] = task
+
+        for layer in self._layers.sublayers():
+            if len(layer.sublayers()) == 0:
+                # Register forward pre hood for leaf layers. This will get the best performance.
+                tasks = []
+                for param in layer.parameters():
+                    if param.trainable:
+                        if param.name in param2task:
+                            tasks.append(param2task[param.name])
+                self._forward_pre_hook_remove_helper.append(
+                    layer.register_forward_pre_hook(
+                        self._forward_pre_hook_function(tasks)))
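To make the hook mechanism above concrete, here is a minimal, single-process sketch of the same pattern: each leaf layer gets a forward pre-hook that waits on the asynchronous broadcast tasks for its parameters before the layer runs. `DummyTask` stands in for the task object that `broadcast(..., sync_op=False)` returns; the layer and hook calls are public `paddle.nn.Layer` APIs.

# Minimal, single-process sketch of the hook-and-wait pattern; not the patch itself.
import paddle
import paddle.nn as nn

class DummyTask:
    def wait(self):
        # A real communication task blocks here until the broadcast has landed
        # in the parameter's memory.
        pass

def make_pre_hook(tasks):
    # Mirrors _forward_pre_hook_function: forward pre-hooks receive (layer, inputs).
    def __impl__(layer, inputs):
        for task in tasks:
            task.wait()
    return __impl__

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

hook_removers = []
for layer in model.sublayers():
    if len(layer.sublayers()) == 0:  # leaf layers only, as in the patch
        hook_removers.append(
            layer.register_forward_pre_hook(make_pre_hook([DummyTask()])))

out = model(paddle.randn([1, 4]))  # each leaf layer waits on its tasks first

for remover in hook_removers:
    remover.remove()  # the patched opt.step() clears these hooks every iteration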
python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py

@@ -101,7 +101,7 @@ class GroupShardedStage2(nn.Layer):
             self._all_params.extend(list(optim.local_params))

         # sharing stage 2 comm overlap flag
-        self._comm_overlap = False
+        self._reduce_overlap = False

         self._trainable_params = []
         self._grad_reduced = []

@@ -309,17 +309,17 @@ class GroupShardedStage2(nn.Layer):
         for grad_storage in self._grad_storage_list:
             grad_storage.reset_checked_in()

-    def _set_comm_overlap(self, comm_overlap):
+    def _set_reduce_overlap(self, reduce_overlap):
         # Hacky way to not add an extra parameter to the `group_sharded_parallel` funct.
         # User should use this like:
         # model, optimizer, scaler = group_sharded_parallel(...)
-        # model._set_comm_overlap(True)
-        self._comm_overlap = comm_overlap
-        if self._comm_overlap:
+        # model._set_reduce_overlap(True)
+        self._reduce_overlap = reduce_overlap
+        if self._reduce_overlap:
             assert len(
                 self._sharding_optimizers
             ) == 1, "Only support comm overlap strategy for single optimizer"
-        self._sharding_optimizers[0]._set_comm_overlap(comm_overlap)
+        self._sharding_optimizers[0]._set_reduce_overlap(reduce_overlap)

     def _get_reduce_fn(self, index, param, dst_rank):
         """

@@ -357,7 +357,7 @@ class GroupShardedStage2(nn.Layer):
                 collective.reduce(tensor=param.grad,
                                   dst=self._group.ranks[dst_rank],
                                   group=self._group,
-                                  sync_op=not self._comm_overlap))
+                                  sync_op=not self._reduce_overlap))

             # Clear the task flow and trigger callback to clear the redundant gradient
             # self._clear_task_flow()

@@ -407,7 +407,7 @@ class GroupShardedStage2(nn.Layer):
                         tensor=grad_storage.buffer,
                         dst=self._group.ranks[grad_storage.destination],
                         group=self._group,
-                        sync_op=not self._comm_overlap))
+                        sync_op=not self._reduce_overlap))

                     cleanup()

@@ -545,7 +545,7 @@ class GroupShardedStage2(nn.Layer):
         opt_step = opt.step

         def _opt_step(self):
-            if self._comm_overlap:
+            if self._reduce_overlap:
                 # Wait for the last reduce task. This wait must before grad scale function.
                 assert self._comm_task is not None
                 self._comm_task.wait()
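The `@@ -545` hunk above sits inside the routine that wraps the sharded optimizer's `step`, so the last asynchronous grad-reduce task is awaited before any gradient scaling or parameter update. A rough sketch of that wrapping pattern follows; how the real code rebinds `opt.step` is not shown in this diff, so the `MethodType` detail here is an assumption.

# Illustrative sketch of the wait-before-step pattern from the hunk above.
from types import MethodType

def redefine_opt_step_sketch(opt):
    opt_step = opt.step  # keep the original bound step method

    def _opt_step(self):
        if self._reduce_overlap:
            # Wait only for the last grad-reduce task; earlier tasks on the
            # same stream are already complete once it returns.
            assert self._comm_task is not None
            self._comm_task.wait()
        opt_step()

    opt.step = MethodType(_opt_step, opt)  # self == opt inside _opt_step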
python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage2_comm_overlap.py

@@ -92,13 +92,15 @@ def train_mlp(model,
     optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)

     if sharding_stage == 2:
+        origin_model = model
         optimizer = GroupShardedOptimizerStage2(
             params=optimizer._parameter_list, optim=optimizer, group=group)

         model = GroupShardedStage2(model,
                                    optimizer,
                                    group=group,
                                    buffer_max_size=2**21)
-        model._set_comm_overlap(True)
+        model._set_reduce_overlap(True)
+        optimizer._set_broadcast_overlap(True, model)
     else:
         model = paddle.DataParallel(model)

@@ -149,6 +151,8 @@ def train_mlp(model,
         optimizer.step()
         optimizer.clear_grad()

+    paddle.device.cuda.synchronize()
+
     if save_model:
         return model, optimizer
     return model.parameters()
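Condensed from the test above, the following sketch shows roughly how a user would enable both overlaps. `_set_reduce_overlap` and `_set_broadcast_overlap` are private APIs of the sharding wrappers; the model, optimizer, and group construction below are placeholders, and the script assumes a multi-GPU collective launch (e.g. `python -m paddle.distributed.launch`).

# Usage sketch under the assumptions stated above; not part of the commit.
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_stage2 import GroupShardedOptimizerStage2
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage2 import GroupShardedStage2

fleet.init(is_collective=True)
group = paddle.distributed.new_group(list(range(paddle.distributed.get_world_size())))

model = paddle.nn.Sequential(paddle.nn.Linear(64, 64), paddle.nn.Linear(64, 10))
optimizer = paddle.optimizer.AdamW(learning_rate=1e-3, parameters=model.parameters())

optimizer = GroupShardedOptimizerStage2(params=optimizer._parameter_list,
                                        optim=optimizer,
                                        group=group)
model = GroupShardedStage2(model, optimizer, group=group, buffer_max_size=2**21)

model._set_reduce_overlap(True)                # overlap grad reduce with backward
optimizer._set_broadcast_overlap(True, model)  # overlap param broadcast with the next forward

# ... training loop: forward / backward / optimizer.step() / optimizer.clear_grad() ...

# With broadcast overlap enabled, synchronize before saving or running inference,
# as the warning added in the optimizer requires.
paddle.device.cuda.synchronize()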