Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
747000dd
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
747000dd
编写于
1月 06, 2022
作者:
J
JZ-LIANG
提交者:
GitHub
1月 06, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Auto Parallel] Pass bugfix (#38741)
上级
aec493c0
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
60 addition
and
83 deletion
+60
-83
python/paddle/distributed/auto_parallel/parallelizer.py
python/paddle/distributed/auto_parallel/parallelizer.py
+20
-33
python/paddle/distributed/passes/auto_parallel_sharding.py
python/paddle/distributed/passes/auto_parallel_sharding.py
+36
-30
python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py
...ttests/distributed_passes/auto_parallel_pass_test_base.py
+2
-2
python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_sharding_pass.py
...ts/distributed_passes/test_auto_parallel_sharding_pass.py
+1
-1
python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py
...le/fluid/tests/unittests/test_auto_parallel_cost_model.py
+0
-3
python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py
...paddle/fluid/tests/unittests/test_auto_parallel_mapper.py
+1
-2
python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py
...uid/tests/unittests/test_auto_parallel_partitioner_gpt.py
+0
-4
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py
...addle/fluid/tests/unittests/test_auto_parallel_reshard.py
+0
-3
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py
...luid/tests/unittests/test_auto_parallel_reshard_dpmppp.py
+0
-3
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py
.../fluid/tests/unittests/test_auto_parallel_reshard_mppp.py
+0
-2
未找到文件。
python/paddle/distributed/auto_parallel/parallelizer.py
浏览文件 @
747000dd
...
...
@@ -23,6 +23,7 @@ import logging
import
pickle
import
time
import
paddle
from
paddle.fluid.backward
import
append_backward
from
paddle.distributed.utils
import
get_logger
from
paddle.distributed.fleet
import
cloud_utils
import
paddle.fluid.core
as
core
...
...
@@ -96,49 +97,35 @@ class AutoParallelizer:
if
suffix
in
attr_name
:
op
.
_remove_attr
(
attr_name
)
def
_apply_serial_
forward_
pass
(
self
,
main_program
,
startup_program
):
def
_apply_serial_pass
(
self
,
main_program
,
startup_program
):
# apply amp
forward
pass
# apply amp pass
if
self
.
_dist_strategy
.
amp
:
auto_parallel_amp_pass
=
new_pass
(
"auto_parallel_amp_pass"
,
self
.
_dist_strategy
.
amp_configs
)
auto_parallel_amp_pass
.
apply
_forward
(
main_program
,
startup_program
,
self
.
_pass_context
)
auto_parallel_amp_pass
.
apply
(
main_program
,
startup_program
,
self
.
_pass_context
)
# apply recompute
forward
pass
# apply recompute pass
if
self
.
_dist_strategy
.
recompute
:
auto_parallel_recompute_pass
=
new_pass
(
"auto_parallel_recompute_pass"
,
self
.
_dist_strategy
.
recompute_configs
)
auto_parallel_recompute_pass
.
apply
_forward
(
main_program
,
startup_program
,
self
.
_pass_context
)
auto_parallel_recompute_pass
.
apply
(
main_program
,
startup_program
,
self
.
_pass_context
)
def
_generate_backward
(
self
,
main_program
,
startup_program
,
loss
,
parameter_list
,
no_grad_set
,
callbacks
):
# apply recompute backward pass
if
self
.
_dist_strategy
.
recompute
:
assert
auto_parallel_recompute_pass
auto_parallel_recompute_pass
.
apply_forward
(
main_program
,
startup_program
,
parameter_list
,
no_grad_set
,
self
.
_pass_context
)
else
:
from
paddle.fluid.backward
import
append_backward
with
program_guard
(
main_program
,
startup_program
):
params_grads
=
append_backward
(
loss
,
parameter_list
,
no_grad_set
,
callbacks
,
distop_context
=
self
.
_dist_context
.
dist_op_context
)
complete_backward_annotation
(
main_program
,
dist_context
=
self
.
_dist_context
)
# apply amp forward pass
if
self
.
_dist_strategy
.
amp
:
assert
auto_parallel_amp_pass
auto_parallel_amp_pass
.
apply_backward
(
main_program
,
startup_program
,
self
.
_pass_context
)
with
program_guard
(
main_program
,
startup_program
):
params_grads
=
append_backward
(
loss
,
parameter_list
,
no_grad_set
,
callbacks
,
distop_context
=
self
.
_dist_context
.
dist_op_context
)
complete_backward_annotation
(
main_program
,
dist_context
=
self
.
_dist_context
)
return
params_grads
...
...
@@ -192,14 +179,14 @@ class AutoParallelizer:
completed_main_program
=
serial_main_program
self
.
_dist_context
=
copy
.
deepcopy
(
dist_context
)
# serial forward pass
self
.
_apply_serial_forward_pass
(
completed_main_program
,
serial_startup_program
)
# serial backward pass
params_grads
=
self
.
_generate_backward
(
completed_main_program
,
serial_startup_program
,
serial_loss
,
self
.
_parameter_list
,
self
.
_no_grad_set
,
self
.
_callbacks
)
# serial forward pass
self
.
_apply_serial_pass
(
completed_main_program
,
serial_startup_program
)
# Logical partition
rank
=
paddle
.
distributed
.
get_rank
()
partitioner
=
Partitioner
(
self
.
_dist_context
,
rank
)
...
...
python/paddle/distributed/passes/auto_parallel_sharding.py
浏览文件 @
747000dd
...
...
@@ -94,7 +94,7 @@ class ShardingPass(PassBase):
def
_collective_data_parallel_groups
(
self
,
main_block
):
for
op
in
main_block
.
ops
:
if
op
.
type
in
_skip_ops
:
if
not
_is_forward_op
(
op
)
or
op
.
type
in
_skip_ops
:
continue
group
=
_inference_data_parallel_group_for_operator
(
self
.
global_rank
,
op
,
self
.
_dist_context
)
...
...
@@ -106,7 +106,7 @@ class ShardingPass(PassBase):
if
len
(
self
.
dp_groups
)
!=
1
:
raise
NotImplementedError
(
"So far Only and Exactly one data parallel group in network are supported, but got [{}] different data parallel groups"
.
format
(
len
(
groups
)))
format
(
len
(
self
.
dp_
groups
)))
def
_build_sharding_infos
(
self
,
params_grads
):
...
...
@@ -193,18 +193,32 @@ class ShardingPass(PassBase):
return
# TODO (JZ-LIANG) support calculate global norm with tensor parallelism
is_clip_grad_by_global_norm
=
False
removed_op_type
=
[
'elementwise_mul'
,
'squared_l2_norm'
,
'clip_by_norm'
]
removed_op_idx
=
set
()
removed_tmp_var
=
set
()
for
idx
,
op
in
list
(
enumerate
(
main_block
.
ops
)):
if
not
_is_gradient_clip_op
(
op
):
continue
if
op
.
type
==
'sum'
:
is_clip_grad_by_global_norm
=
True
break
if
not
is_clip_grad_by_global_norm
:
return
removed_op_idx
=
set
()
removed_tmp_var
=
set
()
if
op
.
type
in
removed_op_type
:
input_name
=
op
.
input
(
"X"
)[
0
]
param_name
=
input_name
[:
input_name
.
find
(
"@GRAD"
)]
if
not
self
.
_is_parameter_in_local_shard
(
param_name
):
removed_op_idx
.
add
(
idx
)
if
op
.
type
in
[
'squared_l2_norm'
,
'clip_by_norm'
]:
for
output_name
in
op
.
output_arg_names
:
removed_tmp_var
.
add
(
output_name
)
for
idx
,
op
in
reversed
(
list
(
enumerate
(
main_block
.
ops
))):
if
not
_is_gradient_clip_op
(
op
):
continue
if
idx
in
removed_op_idx
:
main_block
.
_remove_op
(
idx
,
sync
=
False
)
for
varname
in
removed_tmp_var
:
main_block
.
_remove_var
(
varname
,
sync
=
False
)
for
idx
,
op
in
list
(
enumerate
(
main_block
.
ops
)):
if
not
_is_gradient_clip_op
(
op
):
continue
...
...
@@ -218,7 +232,7 @@ class ShardingPass(PassBase):
sum_op_output
=
op
.
desc
.
output_arg_names
()[
0
]
for
i
,
sharding_info
in
enumerate
(
self
.
sharding_infos
):
new_op
=
main_block
.
_insert_op
(
idx
+
i
,
idx
+
i
+
1
,
type
=
'c_allreduce_sum'
,
inputs
=
{
'X'
:
[
sum_op_output
]},
outputs
=
{
'Out'
:
[
sum_op_output
]},
...
...
@@ -235,21 +249,6 @@ class ShardingPass(PassBase):
new_op
,
dist_attr
.
process_mesh
,
dist_attr
.
dims_mapping
,
self
.
_dist_context
)
break
for
input_name
in
op
.
input_arg_names
:
param_name
=
input_name
[:
input_name
.
find
(
"@GRAD"
)]
if
not
self
.
_is_parameter_in_local_shard
(
param_name
):
removed_op_idx
.
add
(
idx
)
for
output_name
in
op
.
output_arg_names
:
removed_tmp_var
.
add
(
output_name
)
for
idx
,
op
in
reversed
(
list
(
enumerate
(
main_block
.
ops
))):
if
not
_is_gradient_clip_op
(
op
):
continue
if
idx
in
removed_op_idx
:
main_block
.
_remove_op
(
idx
,
sync
=
False
)
for
varname
in
removed_tmp_var
:
main_block
.
_remove_var
(
varname
,
sync
=
False
)
main_block
.
_sync_with_cpp
()
...
...
@@ -424,12 +423,15 @@ class ShardingPass(PassBase):
startup_block
.
_remove_op
(
idx
,
sync
=
False
)
continue
if
op
.
type
!=
"c_broadcast"
and
output_name
in
not_used_param_nane
:
if
op
.
type
!=
"c_broadcast"
and
output_name
in
param_usage
and
sharding_info
.
get_var_rank
(
output_name
)
!=
sharding_info
.
local_rank
:
startup_block
.
_remove_op
(
idx
,
sync
=
False
)
for
varname
in
not_used_param_nane
:
main_block
.
_remove_var
(
varname
,
sync
=
False
)
startup_block
.
_remove_var
(
varname
,
sync
=
False
)
for
param_name
in
param_usage
:
if
sharding_info
.
get_var_rank
(
param_name
)
!=
sharding_info
.
local_rank
:
main_block
.
_remove_var
(
param_name
,
sync
=
False
)
startup_block
.
_remove_var
(
param_name
,
sync
=
False
)
main_block
.
_sync_with_cpp
()
startup_block
.
_sync_with_cpp
()
...
...
@@ -594,6 +596,10 @@ def _is_param_grad_allreduce_op(op, block, dp_ring_ids):
return
block
.
var
(
base_name
).
is_parameter
def
_is_forward_op
(
op
):
return
op
.
attr
(
"op_role"
)
==
0
def
_inference_data_parallel_group_for_operator
(
rank_id
,
op
,
dist_context
):
dp_group
=
None
...
...
python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py
浏览文件 @
747000dd
...
...
@@ -178,13 +178,13 @@ class AutoPallelPassTestBase(DistPassTestBase):
preds
=
model
(
tokens
,
position_ids
,
attention_mask
)
criterion
=
GPTPretrainingCriterion
()
loss
=
criterion
(
preds
,
labels
,
loss_mask
)
clip
=
paddle
.
nn
.
ClipGradByNorm
(
clip_norm
=
1.0
)
optimizer
=
paddle
.
fluid
.
optimizer
.
AdamOptimizer
(
learning_rate
=
0.00001
,
beta1
=
0.9
,
beta2
=
0.999
,
epsilon
=
1e-08
,
grad_clip
=
None
)
grad_clip
=
clip
)
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
)
startup_program
=
paddle
.
static
.
default_startup_program
()
_
,
_
,
dist_startup_prog
,
dist_main_prog
=
optimizer
.
minimize
(
...
...
python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_sharding_pass.py
浏览文件 @
747000dd
...
...
@@ -46,7 +46,7 @@ class TestShardingPass(AutoPallelPassTestBase):
dist_strategy
.
sharding
=
True
dist_strategy
.
sharding_configs
=
{
"sharding_degree"
:
2
,
"stage"
:
3
,
"stage"
:
2
,
}
fleet
.
init
(
is_collective
=
True
,
strategy
=
dist_strategy
)
...
...
python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py
浏览文件 @
747000dd
...
...
@@ -157,9 +157,6 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
complete_train_program
=
auto
.
complete_annotation
(
train_program
,
dist_context
)
parallelizer
.
_apply_serial_forward_pass
(
complete_train_program
,
startup_program
)
params_grads
=
parallelizer
.
_generate_backward
(
complete_train_program
,
startup_program
,
...
...
python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py
浏览文件 @
747000dd
...
...
@@ -478,8 +478,7 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
# auto completion
complete_train_program
=
auto
.
complete_annotation
(
train_program
,
dist_context
)
parallelizer
.
_apply_serial_forward_pass
(
complete_train_program
,
startup_program
)
params_grads
=
parallelizer
.
_generate_backward
(
complete_train_program
,
startup_program
,
...
...
python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py
浏览文件 @
747000dd
...
...
@@ -884,10 +884,6 @@ class TestGPTPartitioner(unittest.TestCase):
complete_train_program
=
auto
.
complete_annotation
(
train_program
,
dist_context
)
# serial forward pass
parallelizer
.
_apply_serial_forward_pass
(
complete_train_program
,
startup_program
)
# serial backward pass
params_grads
=
parallelizer
.
_generate_backward
(
complete_train_program
,
...
...
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py
浏览文件 @
747000dd
...
...
@@ -155,9 +155,6 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
complete_train_program
=
auto
.
complete_annotation
(
train_program
,
dist_context
)
parallelizer
.
_apply_serial_forward_pass
(
complete_train_program
,
startup_program
)
params_grads
=
parallelizer
.
_generate_backward
(
complete_train_program
,
startup_program
,
...
...
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py
浏览文件 @
747000dd
...
...
@@ -119,9 +119,6 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
complete_train_program
=
auto
.
complete_annotation
(
train_program
,
dist_context
)
parallelizer
.
_apply_serial_forward_pass
(
complete_train_program
,
startup_program
)
params_grads
=
parallelizer
.
_generate_backward
(
complete_train_program
,
startup_program
,
...
...
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py
浏览文件 @
747000dd
...
...
@@ -134,8 +134,6 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
# serial forward & backward completion
complete_train_program
=
auto
.
complete_annotation
(
train_program
,
dist_context
)
parallelizer
.
_apply_serial_forward_pass
(
complete_train_program
,
startup_program
)
params_grads
=
parallelizer
.
_generate_backward
(
complete_train_program
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录