PaddlePaddle / Paddle · Commit 74c9b57b
Unverified commit 74c9b57b
Authored Jun 29, 2022 by JZ-LIANG · Committed by GitHub on Jun 29, 2022
[Auto parallel] Bug fixed for GPT3 benchmark (#43793)
* fixed bug for pass & engine
* fixed bug for benchmark GPT-3
Parent: ccfde2da
Showing 5 changed files with 22 additions and 14 deletions (+22, -14)
python/paddle/distributed/auto_parallel/engine.py (+5, -4)
python/paddle/distributed/auto_parallel/operators/common.py (+2, -1)
python/paddle/distributed/auto_parallel/parallelizer_v2.py (+10, -6)
python/paddle/distributed/auto_parallel/utils.py (+4, -2)
python/paddle/distributed/passes/auto_parallel_amp.py (+1, -1)
python/paddle/distributed/auto_parallel/engine.py

@@ -83,6 +83,7 @@ class Engine:
         self._dist_startup_progs = defaultdict(dict)  # dist startup programs
         self._feed_vars = {}
         self._fetch_vars = {}
+        self._planners = {}
 
     def prepare(self,
                 optimizer=None,
@@ -116,13 +117,13 @@ class Engine:
         self._planned_mode = None
         self._modes = ['train', 'eval', 'predict']
 
         # Build forward program
         self._build()
-        # Do auto parallel process
         for mode in self._modes:
             # Do the planning process
             self._plan(mode)
+        for mode in self._modes:
             # Do the parallel process
             self._parallel(mode, all_ranks)
 
         # Init comm and startup program
@@ -185,14 +186,14 @@ class Engine:
         else:
             self._init_dist_context(mode)
 
-        self.planner = Planner(mode, self._dist_contexts[mode])
-        self.planner.plan()
+        self._planners[mode] = Planner(mode, self._dist_contexts[mode])
+        self._planners[mode].plan()
 
     def _parallel(self, mode, all_ranks):
         # Parallelize program based on the planner's results
         # For now, the completer has to be passed to the planner,
         # because we may use it to complete the annotation of the backward and update.
-        parallelizer = Parallelizer(mode, self.planner.completer,
+        parallelizer = Parallelizer(mode, self._planners[mode].completer,
                                     self._dist_contexts[mode])
         if not all_ranks:
             parallelizer.parallel(self._cur_rank)
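
The key change here is replacing the single shared self.planner with a per-mode self._planners dict and splitting one loop into two, so planning for all of 'train', 'eval', and 'predict' finishes before any mode is parallelized, and no mode overwrites another's planner state. A minimal runnable sketch of the pattern; Planner and Engine below are illustrative stand-ins, not the real Paddle classes:

class Planner:
    def __init__(self, mode):
        self.mode = mode
        self.completer = "completer-for-" + mode  # placeholder artifact

    def plan(self):
        print("planning", self.mode)

class Engine:
    def __init__(self):
        self._modes = ['train', 'eval', 'predict']
        self._planners = {}  # one planner per mode, as in the diff

    def prepare(self):
        # Plan every mode first, then parallelize every mode,
        # mirroring the split of the single loop into two loops.
        for mode in self._modes:
            self._planners[mode] = Planner(mode)
            self._planners[mode].plan()
        for mode in self._modes:
            self._parallel(mode)

    def _parallel(self, mode):
        # Each mode reads its own completer instead of whatever the
        # last-planned mode left behind in a shared self.planner.
        print("parallelizing", mode, "with", self._planners[mode].completer)

Engine().prepare()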
python/paddle/distributed/auto_parallel/operators/common.py

@@ -18,7 +18,8 @@ from ..dist_attribute import OperatorDistributedAttribute
 _g_distributed_operator_impl_containers = {}
 
 _g_elementwise_ops = [
-    "elementwise", "gelu", "dropout", "cast", "gather", "concat"
+    "elementwise", "gelu", "dropout", "cast", "gather", "concat",
+    "fused_softmax_mask_upper_triangle"
 ]
 BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale', 'update_loss_scaling'}
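
Adding "fused_softmax_mask_upper_triangle" to _g_elementwise_ops lets auto parallel propagate distributed attributes through the fused op the same way it does for other elementwise-like ops, which the GPT-3 benchmark's fused attention mask requires. A hedged sketch of how such a registry is typically consulted; is_elementwise_op below is illustrative, not necessarily Paddle's exact helper:

_g_elementwise_ops = [
    "elementwise", "gelu", "dropout", "cast", "gather", "concat",
    "fused_softmax_mask_upper_triangle"
]

def is_elementwise_op(op_type):
    # Treat "elementwise" as a family prefix (elementwise_add, ...),
    # everything else as an exact match.
    for name in _g_elementwise_ops:
        if op_type == name or op_type.startswith(name + "_"):
            return True
    return False

assert is_elementwise_op("elementwise_add")
assert is_elementwise_op("fused_softmax_mask_upper_triangle")
assert not is_elementwise_op("matmul_v2")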
python/paddle/distributed/auto_parallel/parallelizer_v2.py

@@ -80,9 +80,9 @@ class Parallelizer:
                                                 rank, dist_params_grads)
         else:
             # Apply pre optimization passes
-            self._apply_pre_optimization(serial_main_program,
-                                         serial_startup_program, None, None,
-                                         None)
+            # self._apply_pre_optimization(serial_main_program,
+            #                              serial_startup_program, None, None,
+            #                              None)
             # Do logical partition
             partitioner = Partitioner(self._dist_context, rank)
             dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
@@ -121,7 +121,9 @@ class Parallelizer:
         if self._strategy is None:
             return
 
         # apply amp pass
-        if self._strategy.amp:
+        # FIXME we disenable amp for eval since it has a little bug with
+        # eval program and which will be fixed in future
+        if self._mode == 'train' and self._strategy.amp:
             config = copy.deepcopy(self._strategy.amp_configs)
             config["dist_context"] = self._dist_context
             config["params_grads"] = params_grads
@@ -139,7 +141,8 @@ class Parallelizer:
                                         self._pass_context)
 
         # apply recompute pass
-        if self._strategy.recompute:
+        # recompute is then train-only optimization
+        if self._mode == "train" and self._strategy.recompute:
             config = copy.deepcopy(self._strategy.recompute_configs)
             config["dist_context"] = self._dist_context
             config["no_grad_set"] = None
@@ -164,7 +167,8 @@ class Parallelizer:
             auto_parallel_sharding_pass.apply([main_program],
                                               [startup_program],
                                               self._pass_context)
 
-        if self._strategy.gradient_merge:
+        # recompute is then train-only optimization
+        if self._mode == "train" and self._strategy.gradient_merge:
             config = copy.deepcopy(self._strategy.gradient_merge_configs)
             config["dist_context"] = self._dist_context
             config["params_grads"] = params_grads
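
The last three hunks apply the same gate: passes that only make sense during training (amp, recompute, gradient merge) now check self._mode == 'train' before running, so eval and predict programs are left untouched. A runnable sketch of the pattern; Strategy and apply_amp_pass are placeholders, not Paddle APIs:

import copy

class Strategy:
    amp = True
    amp_configs = {"init_loss_scaling": 32768.0}

def apply_amp_pass(config):
    print("amp pass with", sorted(config))

def apply_post_optimization(mode, strategy, dist_context, params_grads):
    # amp now runs only for 'train'; 'eval' and 'predict' skip it,
    # matching the FIXME about the eval-program bug in the diff.
    if mode == 'train' and strategy.amp:
        config = copy.deepcopy(strategy.amp_configs)
        config["dist_context"] = dist_context
        config["params_grads"] = params_grads
        apply_amp_pass(config)

apply_post_optimization('train', Strategy(), object(), [])  # pass runs
apply_post_optimization('eval', Strategy(), object(), [])   # skipped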
python/paddle/distributed/auto_parallel/utils.py

@@ -1057,13 +1057,15 @@ def set_grad_var_shape(program, dist_context):
         "transpose2_grad", "softmax_grad", "cross_entropy_grad2",
         "dropout_grad", "tanh_grad", "slice", "assign",
         "matmul_v2_triple_grad", "elementwise_add_triple_grad",
-        "fill_constant", "sqrt_grad"
+        "fill_constant", "sqrt_grad",
+        "fused_softmax_mask_upper_triangle_grad"
     ]
     forward_list = [
         "reshape2", "softmax_with_cross_entropy", "transpose2", "softmax",
         "cross_entropy2", "dropout", "tanh", ["slice_grad", "c_allgather"],
         "assign", "matmul_v2_grad_grad",
-        "elementwise_add_grad_grad", "shape", "sqrt"
+        "elementwise_add_grad_grad", "shape", "sqrt",
+        "fused_softmax_mask_upper_triangle_grad"
     ]
 
     if op.type in need_set_shape_list:
         for forward_op in block.ops:
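
The two lists in set_grad_var_shape are paired by position: the grad op type at index i in need_set_shape_list corresponds to the entry at index i in forward_list, which is why both lists grow together in this change. An abridged, runnable sketch of the pairing (lists shortened for illustration):

need_set_shape_list = ["softmax_grad", "dropout_grad", "slice"]
forward_list = ["softmax", "dropout", ["slice_grad", "c_allgather"]]

def forward_entry_for(grad_op_type):
    # Same index in the second list names the op(s) whose output shape
    # the grad var should be matched against.
    idx = need_set_shape_list.index(grad_op_type)
    return forward_list[idx]

print(forward_entry_for("dropout_grad"))  # -> dropout
print(forward_entry_for("slice"))         # -> ['slice_grad', 'c_allgather']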
python/paddle/distributed/passes/auto_parallel_amp.py

@@ -143,8 +143,8 @@ class AMPState(object):
         """
         num_cast_ops = 0
+        var_name_dict = {}
 
         for in_name in op.input_names:
-            var_name_dict = {}
             if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
                     op, in_name):
                 continue
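
This one-line move hoists var_name_dict out of the per-input loop, so cast-variable mappings collected for one input slot are no longer wiped when the next slot is processed. A small sketch of the difference, with illustrative names, assuming the dict is meant to accumulate mappings for the whole op:

def collect_casts(input_names, reset_per_input):
    var_name_dict = {}
    for in_name in input_names:
        if reset_per_input:
            var_name_dict = {}  # old behavior: wipes earlier slots
        var_name_dict[in_name] = in_name + ".cast_fp16"
    return var_name_dict

print(collect_casts(["X", "Y"], reset_per_input=True))
# {'Y': 'Y.cast_fp16'} -- only the last input survives (the bug)
print(collect_casts(["X", "Y"], reset_per_input=False))
# {'X': 'X.cast_fp16', 'Y': 'Y.cast_fp16'} -- full mapping (the fix)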