BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Commit c20eb7a6 (unverified)
Authored on Nov 17, 2022 by wuhuachaocoding; committed via GitHub on Nov 17, 2022
support stage2 for gradient merge. (#47711)
Parent: 460d5040
Showing 3 changed files with 29 additions and 39 deletions:

python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py (+28, -22)
python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py (+1, -12)
python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_api_eager.py (+0, -5)
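For context, the sketch below illustrates the gradient-merge (accumulation) pattern this commit targets: gradients are accumulated locally over several micro-batches, and the outer data-parallel all-reduce is deferred until just before the optimizer step, which is what the new _dp_allreduce hook in stage2 does. This is an illustrative sketch, not code from the commit; model, opt, batches, and dp_group are assumed placeholders, and gradient scaling/averaging is omitted for brevity.

import paddle
import paddle.distributed as dist


def train_with_gradient_merge(model, opt, batches, dp_group, acc_steps=4):
    # Illustrative placeholders: `model`, `opt`, `batches`, `dp_group`.
    for step, (x, y) in enumerate(batches):
        loss = paddle.nn.functional.mse_loss(model(x), y)
        loss.backward()  # gradients accumulate (merge) across micro-batches

        if (step + 1) % acc_steps == 0:
            # Defer the outer data-parallel all-reduce until the merged
            # gradient is ready, then step once per acc_steps micro-batches.
            for p in model.parameters():
                if p.grad is not None:
                    dist.all_reduce(tensor=p.grad, group=dp_group, sync_op=True)
            opt.step()
            opt.clear_grad()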
python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
@@ -418,17 +418,6 @@ class GroupShardedStage2(nn.Layer):
                         )
                     )
 
-                    if self._dp_group and self._dp_group.nranks > 1:
-                        assert (
-                            not self._reduce_overlap
-                        ), 'dp + stage2 hybrid parallel only Synchronize due to the new communication lib.'
-                        # TODO(wuhuachao):after the new communication lib upgrading, overlapping the comm of dp + stage2.
-                        dist.all_reduce(
-                            tensor=param.grad,
-                            group=self._dp_group,
-                            sync_op=True,
-                        )
-
                     # Clear the task flow and trigger callback to clear the redundant gradient
                     # self._clear_task_flow()
@@ -485,17 +474,6 @@ class GroupShardedStage2(nn.Layer):
                         )
                     )
 
-                    if self._dp_group and self._dp_group.nranks > 1:
-                        assert (
-                            not self._reduce_overlap
-                        ), 'dp + stage2 hybrid parallel only Synchronize due to the new communication lib.'
-                        # TODO(wuhuachao):after the new communication lib upgrading, overlapping the comm of dp + stage2.
-                        dist.all_reduce(
-                            tensor=grad_storage.buffer,
-                            group=self._dp_group,
-                            sync_op=True,
-                        )
-
                     cleanup()
 
                     # Clear the task flow and trigger callback to clear the redundant gradient
@@ -648,8 +626,34 @@ class GroupShardedStage2(nn.Layer):
                     )
         return rank_buffer_size
 
+    def _dp_allreduce(self):
+        # do dp allreduce here for gradient merge.
+        if self._dp_group and self._dp_group.nranks > 1:
+            for dtype in self._grad_storages.keys():
+                for rank, g in sorted(
+                    self._grad_storages[dtype].items(), key=lambda x: x[0]
+                ):
+                    if g.destination == self._rank:
+                        assert g.buffer._is_initialized()
+                        dist.all_reduce(
+                            tensor=g.buffer,
+                            group=self._dp_group,
+                            sync_op=True,
+                        )
+
+            for param in self._trainable_params:
+                if param.name in self._param_grads and param.grad is not None:
+                    dst_rank = self._trainable_param2rank[param.name]
+                    if dst_rank == self._rank:
+                        dist.all_reduce(
+                            tensor=param.grad,
+                            group=self._dp_group,
+                            sync_op=True,
+                        )
+
     def _redefine_opt_step(self):
         grad_func = self._grad_scale
+        dp_allreduce_func = self._dp_allreduce
         for opt in self._sharding_optimizers:
             opt_step = opt.step
@@ -658,7 +662,9 @@ class GroupShardedStage2(nn.Layer):
                     # Wait for the last reduce task. This wait must before grad scale function.
                     assert self._comm_task is not None
                     self._comm_task.wait()
 
                 grad_func()
+                dp_allreduce_func()
                 opt_step()
 
             opt.step = MethodType(_opt_step, opt)
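The _redefine_opt_step change above rebinds opt.step so that the dp all-reduce runs right before the original optimizer step. A minimal, self-contained sketch of that rebinding pattern follows; the class and function names here are hypothetical, with only the types.MethodType usage taken from the diff.

from types import MethodType


class ToyOptimizer:
    # Hypothetical stand-in for a sharding optimizer.
    def step(self):
        print("original optimizer step")


def redefine_opt_step(opt, dp_allreduce_func):
    opt_step = opt.step  # keep a reference to the original bound method

    def _opt_step(self):
        dp_allreduce_func()  # e.g. all-reduce the merged gradients first
        opt_step()           # then run the original step

    # Rebind on this instance only, as MethodType(_opt_step, opt) does above.
    opt.step = MethodType(_opt_step, opt)


opt = ToyOptimizer()
redefine_opt_step(opt, lambda: print("dp allreduce"))
opt.step()  # prints "dp allreduce", then "original optimizer step"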
python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
@@ -23,7 +23,6 @@ from paddle.fluid import core
 from paddle.fluid import layers
 from paddle.fluid.dygraph import to_variable
 from paddle.fluid.framework import dygraph_only
-from paddle.distributed import fleet, ParallelMode
 
 
 class Taskflow:
@@ -245,18 +244,8 @@ def GroupShardedScaler(scaler):
         self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0
 
         is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")
 
-        hcg = fleet.fleet._hcg if hasattr(fleet.fleet, "_hcg") else None
-        hybrid_parallel = (
-            hcg is not None
-            and hcg.get_parallel_mode() is not ParallelMode.DATA_PARALLEL
-        )
-
         paddle.distributed.all_reduce(
-            is_found_inf,
-            op=paddle.distributed.ReduceOp.MAX,
-            group=hcg.get_check_parallel_group()
-            if hybrid_parallel
-            else optimizer._group,
+            is_found_inf, op=paddle.distributed.ReduceOp.MAX, group=None
         )
 
         self._found_inf = is_found_inf.numpy()[0]
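The GroupShardedScaler change above now all-reduces the overflow flag with ReduceOp.MAX over the default group instead of a hybrid-parallel check group. A minimal sketch of that flag-combining pattern follows; the helper name and the group argument are illustrative assumptions.

import paddle
import paddle.distributed as dist


def sync_found_inf(local_found_inf, group=None):
    # If any rank saw inf/nan gradients, a MAX reduce makes every rank see 1.
    flag = paddle.to_tensor([1 if local_found_inf else 0], dtype="int32")
    dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=group)
    return bool(flag.numpy()[0])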
python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_api_eager.py
@@ -148,11 +148,6 @@ def test_sharding_api():
     output_dir = tempfile.mkdtemp()
 
-    # test sharding + dp, just for test
-    dp_group = paddle.distributed.new_group(
-        list(range(paddle.distributed.get_world_size()))
-    )
-
     # fp16
     stage2_params = train_mlp(
         mlp1,