Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
c036c5c0
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c036c5c0
编写于
10月 28, 2022
作者:
S
sneaxiy
提交者:
GitHub
10月 28, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add fused_allreduce_gradients_with_group for PPFleetX (#47447)
* add fused_allreduce_gradients_with_group * add scale * fix ci
上级
17fb92b3
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
34 addition
and
16 deletion
+34
-16
python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
...on/paddle/distributed/fleet/utils/hybrid_parallel_util.py
+34
-16
未找到文件。
python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
浏览文件 @
c036c5c0
...
...
@@ -26,7 +26,7 @@ from .log_util import logger
__all__
=
[]
def
_apply_collective_grads
(
parameters
,
comm_group
):
def
_apply_collective_grads
(
parameters
,
comm_group
,
bucket_size
,
scale
=
None
):
grad_var_set
=
set
()
grad_vars
=
[]
sparse_grad_vars
=
[]
...
...
@@ -41,28 +41,35 @@ def _apply_collective_grads(parameters, comm_group):
assert
g_var
not
in
grad_var_set
grad_var_set
.
add
(
g_var
)
coalesced_grads_and_vars
=
build_groups
(
grad_vars
,
128
*
1024
*
1024
)
coalesced_grads_and_vars
=
build_groups
(
grad_vars
,
bucket_size
)
nranks
=
(
paddle
.
distributed
.
get_world_size
()
if
comm_group
is
None
else
comm_group
.
nranks
)
scale
=
nranks
if
scale
is
None
else
1.0
/
scale
scale
=
None
if
scale
==
1.0
else
scale
for
coalesced_grad
,
_
,
_
in
coalesced_grads_and_vars
:
# need to div nranks
div_factor
=
paddle
.
to_tensor
(
nranks
,
dtype
=
coalesced_grad
.
dtype
)
paddle
.
fluid
.
framework
.
_dygraph_tracer
().
trace_op
(
type
=
"elementwise_div"
,
inputs
=
{
'X'
:
coalesced_grad
,
'Y'
:
div_factor
},
outputs
=
{
'Out'
:
coalesced_grad
},
attrs
=
{
'axis'
:
-
1
},
)
if
scale
is
not
None
:
div_factor
=
paddle
.
to_tensor
(
scale
,
dtype
=
coalesced_grad
.
dtype
)
paddle
.
fluid
.
framework
.
_dygraph_tracer
().
trace_op
(
type
=
"elementwise_div"
,
inputs
=
{
'X'
:
coalesced_grad
,
'Y'
:
div_factor
},
outputs
=
{
'Out'
:
coalesced_grad
},
attrs
=
{
'axis'
:
-
1
},
)
paddle
.
distributed
.
all_reduce
(
coalesced_grad
,
group
=
comm_group
)
_split_tensors
(
coalesced_grads_and_vars
)
def
_apply_collective_grads_eager
(
parameters
,
comm_group
):
def
_apply_collective_grads_eager
(
parameters
,
comm_group
,
bucket_size
,
scale
=
None
):
grad_var_set
=
set
()
grad_vars
=
[]
...
...
@@ -76,16 +83,21 @@ def _apply_collective_grads_eager(parameters, comm_group):
assert
g_var
not
in
grad_var_set
grad_var_set
.
add
(
g_var
)
coalesced_grads_and_vars
=
build_groups
(
grad_vars
,
128
*
1024
*
1024
)
coalesced_grads_and_vars
=
build_groups
(
grad_vars
,
bucket_size
)
nranks
=
(
paddle
.
distributed
.
get_world_size
()
if
comm_group
is
None
else
comm_group
.
nranks
)
scale
=
1.0
/
nranks
if
scale
is
None
else
scale
scale
=
None
if
scale
==
1.0
else
scale
for
coalesced_grad
,
_
,
_
in
coalesced_grads_and_vars
:
# need to div nranks
coalesced_grad
.
scale_
(
1.0
/
nranks
)
if
scale
is
not
None
:
coalesced_grad
.
scale_
(
scale
)
paddle
.
distributed
.
all_reduce
(
coalesced_grad
,
group
=
comm_group
)
_split_tensors
(
coalesced_grads_and_vars
)
...
...
@@ -172,16 +184,22 @@ def broadcast_dp_parameters(model, hcg):
)
def
fused_allreduce_gradients
(
parameter_list
,
hcg
):
data_parallel_group
=
None
if
hcg
is
None
else
hcg
.
get_data_parallel_group
()
logger
.
debug
(
"dp start fuse allreduce gradients"
)
def
fused_allreduce_gradients
_with_group
(
parameter_list
,
group
,
bucket_size
=
128
*
1024
*
1024
,
scale
=
None
):
apply_func
=
(
_apply_collective_grads_eager
if
in_dygraph_mode
()
else
_apply_collective_grads
)
with
framework
.
no_grad
():
apply_func
(
parameter_list
,
data_parallel_group
)
apply_func
(
parameter_list
,
group
,
bucket_size
)
def
fused_allreduce_gradients
(
parameter_list
,
hcg
):
data_parallel_group
=
None
if
hcg
is
None
else
hcg
.
get_data_parallel_group
()
logger
.
debug
(
"dp start fuse allreduce gradients"
)
fused_allreduce_gradients_with_group
(
parameter_list
,
data_parallel_group
)
def
sharding_reduce_gradients
(
parameter_list
,
hcg
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录