Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
74c9b57b
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
74c9b57b
编写于
6月 29, 2022
作者:
J
JZ-LIANG
提交者:
GitHub
6月 29, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Auto parallel] Bug fixed for GPT3 benchmark (#43793)
* fixed bug for pass & engine * fixed bug for benchmark GPT-3
上级
ccfde2da
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
22 addition
and
14 deletion
+22
-14
python/paddle/distributed/auto_parallel/engine.py
python/paddle/distributed/auto_parallel/engine.py
+5
-4
python/paddle/distributed/auto_parallel/operators/common.py
python/paddle/distributed/auto_parallel/operators/common.py
+2
-1
python/paddle/distributed/auto_parallel/parallelizer_v2.py
python/paddle/distributed/auto_parallel/parallelizer_v2.py
+10
-6
python/paddle/distributed/auto_parallel/utils.py
python/paddle/distributed/auto_parallel/utils.py
+4
-2
python/paddle/distributed/passes/auto_parallel_amp.py
python/paddle/distributed/passes/auto_parallel_amp.py
+1
-1
未找到文件。
python/paddle/distributed/auto_parallel/engine.py
浏览文件 @
74c9b57b
...
...
@@ -83,6 +83,7 @@ class Engine:
self
.
_dist_startup_progs
=
defaultdict
(
dict
)
# dist startup programs
self
.
_feed_vars
=
{}
self
.
_fetch_vars
=
{}
self
.
_planners
=
{}
def
prepare
(
self
,
optimizer
=
None
,
...
...
@@ -116,13 +117,13 @@ class Engine:
self
.
_planned_mode
=
None
self
.
_modes
=
[
'train'
,
'eval'
,
'predict'
]
# Build forward program
self
.
_build
()
# Do auto parallel process
for
mode
in
self
.
_modes
:
# Do the planning process
self
.
_plan
(
mode
)
for
mode
in
self
.
_modes
:
# Do the parallel process
self
.
_parallel
(
mode
,
all_ranks
)
# Init comm and startup program
...
...
@@ -185,14 +186,14 @@ class Engine:
else
:
self
.
_init_dist_context
(
mode
)
self
.
planner
=
Planner
(
mode
,
self
.
_dist_contexts
[
mode
])
self
.
planner
.
plan
()
self
.
_planners
[
mode
]
=
Planner
(
mode
,
self
.
_dist_contexts
[
mode
])
self
.
_planners
[
mode
]
.
plan
()
def
_parallel
(
self
,
mode
,
all_ranks
):
# Parallelize program based on the planner's results
# For now, the completer has to be passed to the planner,
# because we may use it to complete the annotation of the backwarkward and update.
parallelizer
=
Parallelizer
(
mode
,
self
.
planner
.
completer
,
parallelizer
=
Parallelizer
(
mode
,
self
.
_planners
[
mode
]
.
completer
,
self
.
_dist_contexts
[
mode
])
if
not
all_ranks
:
parallelizer
.
parallel
(
self
.
_cur_rank
)
...
...
python/paddle/distributed/auto_parallel/operators/common.py
浏览文件 @
74c9b57b
...
...
@@ -18,7 +18,8 @@ from ..dist_attribute import OperatorDistributedAttribute
_g_distributed_operator_impl_containers
=
{}
_g_elementwise_ops
=
[
"elementwise"
,
"gelu"
,
"dropout"
,
"cast"
,
"gather"
,
"concat"
"elementwise"
,
"gelu"
,
"dropout"
,
"cast"
,
"gather"
,
"concat"
,
"fused_softmax_mask_upper_triangle"
]
BACKWARD_ONLY_DIST_OPS
=
{
'check_finite_and_unscale'
,
'update_loss_scaling'
}
...
...
python/paddle/distributed/auto_parallel/parallelizer_v2.py
浏览文件 @
74c9b57b
...
...
@@ -80,9 +80,9 @@ class Parallelizer:
rank
,
dist_params_grads
)
else
:
# Apply pre optimization passes
self
.
_apply_pre_optimization
(
serial_main_program
,
serial_startup_program
,
None
,
None
,
None
)
#
self._apply_pre_optimization(serial_main_program,
#
serial_startup_program, None, None,
#
None)
# Do logical partition
partitioner
=
Partitioner
(
self
.
_dist_context
,
rank
)
dist_main_prog
,
dist_startup_prog
,
dist_params_grads
=
partitioner
.
partition
(
...
...
@@ -121,7 +121,9 @@ class Parallelizer:
if
self
.
_strategy
is
None
:
return
# apply amp pass
if
self
.
_strategy
.
amp
:
# FIXME we disenable amp for eval since it has a little bug with
# eval program and which will be fixed in future
if
self
.
_mode
==
'train'
and
self
.
_strategy
.
amp
:
config
=
copy
.
deepcopy
(
self
.
_strategy
.
amp_configs
)
config
[
"dist_context"
]
=
self
.
_dist_context
config
[
"params_grads"
]
=
params_grads
...
...
@@ -139,7 +141,8 @@ class Parallelizer:
self
.
_pass_context
)
# apply recompute pass
if
self
.
_strategy
.
recompute
:
# recompute is then train-only optimization
if
self
.
_mode
==
"train"
and
self
.
_strategy
.
recompute
:
config
=
copy
.
deepcopy
(
self
.
_strategy
.
recompute_configs
)
config
[
"dist_context"
]
=
self
.
_dist_context
config
[
"no_grad_set"
]
=
None
...
...
@@ -164,7 +167,8 @@ class Parallelizer:
auto_parallel_sharding_pass
.
apply
([
main_program
],
[
startup_program
],
self
.
_pass_context
)
if
self
.
_strategy
.
gradient_merge
:
# recompute is then train-only optimization
if
self
.
_mode
==
"train"
and
self
.
_strategy
.
gradient_merge
:
config
=
copy
.
deepcopy
(
self
.
_strategy
.
gradient_merge_configs
)
config
[
"dist_context"
]
=
self
.
_dist_context
config
[
"params_grads"
]
=
params_grads
...
...
python/paddle/distributed/auto_parallel/utils.py
浏览文件 @
74c9b57b
...
...
@@ -1057,13 +1057,15 @@ def set_grad_var_shape(program, dist_context):
"transpose2_grad"
,
"softmax_grad"
,
"cross_entropy_grad2"
,
"dropout_grad"
,
"tanh_grad"
,
"slice"
,
"assign"
,
"matmul_v2_triple_grad"
,
"elementwise_add_triple_grad"
,
"fill_constant"
,
"sqrt_grad"
"fill_constant"
,
"sqrt_grad"
,
"fused_softmax_mask_upper_triangle_grad"
]
forward_list
=
[
"reshape2"
,
"softmax_with_cross_entropy"
,
"transpose2"
,
"softmax"
,
"cross_entropy2"
,
"dropout"
,
"tanh"
,
[
"slice_grad"
,
"c_allgather"
],
"assign"
,
"matmul_v2_grad_grad"
,
"elementwise_add_grad_grad"
,
"shape"
,
"sqrt"
"elementwise_add_grad_grad"
,
"shape"
,
"sqrt"
,
"fused_softmax_mask_upper_triangle_grad"
]
if
op
.
type
in
need_set_shape_list
:
for
forward_op
in
block
.
ops
:
...
...
python/paddle/distributed/passes/auto_parallel_amp.py
浏览文件 @
74c9b57b
...
...
@@ -143,8 +143,8 @@ class AMPState(object):
"""
num_cast_ops
=
0
var_name_dict
=
{}
for
in_name
in
op
.
input_names
:
var_name_dict
=
{}
if
src_dtype
==
core
.
VarDesc
.
VarType
.
FP32
and
_keep_fp32_input
(
op
,
in_name
):
continue
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录