PaddlePaddle / Paddle, commit 38ec37cd (unverified)
Authored Apr 19, 2023 by kangguangli; committed via GitHub, Apr 19, 2023.
[Perf] fix static graph performance issue in amp mode with multicard (#52724)
* fix
* fix
* fix
* fix
* fix
* fix fuse group order
Parent: f6f18835

Showing 3 changed files with 47 additions and 14 deletions (+47 -14)
python/paddle/distributed/fleet/fleet.py (+1 -1)
python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py (+44 -13)
python/paddle/fluid/framework.py (+2 -0)
python/paddle/distributed/fleet/fleet.py

@@ -1534,7 +1534,7 @@ class Fleet:
             # i.e. users can not modify current computation graph anymore
             context["graph_optimize_ops"] = optimize_ops
             context["graph_optimize_grads"] = params_grads
-        else:
+        elif loss.block.program._pass_applied is None:
             apply_ir_passes(loss.block.program, startup_program, self)
         if not self._role_maker._is_heter_parameter_server_mode:
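Together with the `Program._pass_applied` attribute that framework.py initializes below, this one-line change turns an unconditional `else` into a run-at-most-once guard for the IR passes. A minimal standalone sketch of the intended control flow, using stand-in stubs rather than Paddle's real classes:

    # Hedged sketch with stand-in objects; not Paddle's real API.
    class Program:
        def __init__(self):
            self._pass_applied = None  # sentinel added by this commit

    def apply_ir_passes(program):
        print("applying IR passes")
        program._pass_applied = True  # raw_program_optimizer sets this flag

    program = Program()
    for _ in range(2):
        if program._pass_applied is None:  # the new elif condition
            apply_ir_passes(program)       # runs only the first time
        else:
            print("passes already applied; skipping")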
python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py

@@ -13,6 +13,8 @@
 from paddle import static
 from paddle.fluid import core
+from paddle.framework import _global_flags
+from paddle.framework.ir import apply_build_strategy
 from paddle.utils import unique_name
 from .common import (
@@ -146,6 +148,18 @@ class RawProgramOptimizer(MetaOptimizerBase):
         optimize_ops, params_grads = self.inner_opt.minimize(
             loss, startup_program, parameter_list, no_grad_set
         )
+        if _global_flags()['FLAGS_apply_pass_to_program']:
+            pass_attrs = {"use_cuda": True}
+            build_strategy = self.user_defined_strategy.build_strategy._copy()
+            build_strategy.fuse_all_optimizer_ops = False
+            build_strategy.fuse_all_reduce_ops = False
+            apply_build_strategy(
+                self.main_program,
+                self.startup_program,
+                build_strategy,
+                pass_attrs,
+            )
+            self.main_program._pass_applied = True
         if self.nranks == 1:
             return optimize_ops, params_grads
         self._init_process_group()
@@ -357,24 +371,39 @@ class RawProgramOptimizer(MetaOptimizerBase):
         # [([grad0, grad1], [param0, param1]), ([grad2, grad3], [param2, param3])]
         # each entry of the list is a tuple stores the grads segment list and
         # the corresponding params segment list
-        grad_param_segments = []
-        last_dtype = None
+        # its type is: dict[dtype, list[tuple[list[grad], list[param]]]]
+        grad_param_segments_by_dtype = {}
         # split the grad based on dtype and fused size
         for param, grad in param_grads:
-            if (
-                len(grad_param_segments) == 0
-                or len(grad_param_segments[-1][0]) == self.fuse_grad_size_in_num
-                or grad.dtype != last_dtype
-            ):
-                grad_param_segments.append(([grad], [param]))
-                last_dtype = grad.dtype
-            else:
-                grad_param_segments[-1][0].append(grad)
-                grad_param_segments[-1][1].append(param)
+            if grad.dtype not in grad_param_segments_by_dtype:
+                grad_param_segments_by_dtype[grad.dtype] = [([], [])]
+            grad_segment, param_segment = grad_param_segments_by_dtype[
+                grad.dtype
+            ][-1]
+            if len(param_segment) == self.fuse_grad_size_in_num:
+                grad_param_segments_by_dtype[grad.dtype].append(([], []))
+                grad_segment, param_segment = grad_param_segments_by_dtype[
+                    grad.dtype
+                ][-1]
+            param_segment.append(param)
+            grad_segment.append(grad)
+
+        grad_param_segments = []
+        for _, group in grad_param_segments_by_dtype.items():
+            grad_param_segments.extend(group)

         if len(grad_param_segments) == 0:
             return

+        # because the regroup operation make the relative order invalid,
+        # we need to reorder these fuse group by after_idx
+        def get_after_idx_of_fuse_group(grad_param_segments):
+            grad_segment, param_segment = grad_param_segments
+            return max([outputs_name_to_idx[grad][1] for grad in grad_segment])
+
+        grad_param_segments.sort(key=get_after_idx_of_fuse_group)
+
         fused_vars = [None] * len(grad_param_segments)
         for i in range(len(grad_param_segments) - 1, -1, -1):
             # travers the grad_param_segments in backward
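The rewrite replaces the old sequential rule ("open a new group whenever the dtype changes or the current group is full") with per-dtype buckets that are flattened afterwards. Under AMP, fp16 and fp32 gradients interleave, so the old rule fragmented the stream into many tiny groups and hence many small allreduces. A self-contained sketch of the new bucketing (the `Grad` namedtuple and the free-standing function are illustrative, not Paddle's API):

    from collections import namedtuple

    Grad = namedtuple('Grad', ['name', 'dtype'])

    def group_by_dtype(param_grads, fuse_grad_size_in_num):
        # dict[dtype, list[tuple[list[grad], list[param]]]], as in the diff
        segments_by_dtype = {}
        for param, grad in param_grads:
            if grad.dtype not in segments_by_dtype:
                segments_by_dtype[grad.dtype] = [([], [])]
            grad_seg, param_seg = segments_by_dtype[grad.dtype][-1]
            if len(param_seg) == fuse_grad_size_in_num:
                # current bucket for this dtype is full: start a new one
                segments_by_dtype[grad.dtype].append(([], []))
                grad_seg, param_seg = segments_by_dtype[grad.dtype][-1]
            param_seg.append(param)
            grad_seg.append(grad)
        # flatten the per-dtype buckets into one segment list
        segments = []
        for group in segments_by_dtype.values():
            segments.extend(group)
        return segments

    # AMP-style interleaved dtypes: fp32, fp16, fp32, fp16, ...
    pgs = [(f'p{i}', Grad(f'g{i}', 'fp16' if i % 2 else 'fp32'))
           for i in range(6)]
    for grads, params in group_by_dtype(pgs, fuse_grad_size_in_num=4):
        print([g.name for g in grads], params)

Running this yields one fp32 group and one fp16 group instead of six singleton groups, which is the multicard AMP performance fix the commit title refers to.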
@@ -390,7 +419,9 @@ class RawProgramOptimizer(MetaOptimizerBase):
                 stop_gradient=True,
             )
             fused_vars[i] = fused_var
-            after_idx = outputs_name_to_idx[grad_segment[-1]][1]
+            after_idx = max(
+                [outputs_name_to_idx[grad][1] for grad in grad_segment]
+            )
             block._insert_op_without_sync(
                 after_idx + 1,
                 type='c_allreduce_sum',
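With dtype-based regrouping, the last gradient appended to a segment is no longer guaranteed to be the one produced latest in the block, so the insertion point for the fused allreduce must be the maximum producer index over the whole segment. A hedged sketch (the tuple layout of `outputs_name_to_idx`, a grad name mapped to its first and last producing op index, is inferred from its usage in this diff):

    # Hypothetical op indices for the grads of one dtype-bucketed segment.
    outputs_name_to_idx = {'g0': (0, 2), 'g2': (1, 7), 'g4': (3, 5)}
    grad_segment = ['g0', 'g2', 'g4']  # block order need not be monotone here

    after_idx = max(outputs_name_to_idx[g][1] for g in grad_segment)
    assert after_idx == 7  # the old grad_segment[-1] rule picks 5: too early
    print("insert c_allreduce_sum after op index", after_idx)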
python/paddle/fluid/framework.py

@@ -5335,6 +5335,8 @@ class Program:
         self._fleet_opt = None
         self._program_config = None

+        self._pass_applied = None
+
         # assigned if this program has been parsed by a pipeline optimizer
         self._pipeline_opt = None
saxon_zh (@saxon_zh) mentioned this in commit 3603b9b106adef17a8caa24080f6c312b07cb638 on Jun 08, 2023.