Unverified commit 747000dd
Authored Jan 06, 2022 by JZ-LIANG, committed via GitHub on Jan 06, 2022

[Auto Parallel] Pass bugfix (#38741)

Parent: aec493c0
Showing 10 changed files with 60 additions and 83 deletions (+60 −83)
python/paddle/distributed/auto_parallel/parallelizer.py                                      +20 −33
python/paddle/distributed/passes/auto_parallel_sharding.py                                   +36 −30
python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py       +2 −2
python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_sharding_pass.py   +1 −1
python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py                         +0 −3
python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py                             +1 −2
python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py                    +0 −4
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py                            +0 −3
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py                     +0 −3
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py                       +0 −2
python/paddle/distributed/auto_parallel/parallelizer.py

@@ -23,6 +23,7 @@ import logging
 import pickle
 import time
 import paddle
+from paddle.fluid.backward import append_backward
 from paddle.distributed.utils import get_logger
 from paddle.distributed.fleet import cloud_utils
 import paddle.fluid.core as core
@@ -96,49 +97,35 @@ class AutoParallelizer:
             if suffix in attr_name:
                 op._remove_attr(attr_name)
 
-    def _apply_serial_forward_pass(self, main_program, startup_program):
+    def _apply_serial_pass(self, main_program, startup_program):
 
-        # apply amp forward pass
+        # apply amp pass
         if self._dist_strategy.amp:
             auto_parallel_amp_pass = new_pass("auto_parallel_amp_pass",
                                               self._dist_strategy.amp_configs)
-            auto_parallel_amp_pass.apply_forward(main_program, startup_program,
-                                                 self._pass_context)
+            auto_parallel_amp_pass.apply(main_program, startup_program,
+                                         self._pass_context)
 
-        # apply recompute forward pass
+        # apply recompute pass
         if self._dist_strategy.recompute:
             auto_parallel_recompute_pass = new_pass(
                 "auto_parallel_recompute_pass",
                 self._dist_strategy.recompute_configs)
-            auto_parallel_recompute_pass.apply_forward(
-                main_program, startup_program, self._pass_context)
+            auto_parallel_recompute_pass.apply(main_program, startup_program,
+                                               self._pass_context)
 
     def _generate_backward(self, main_program, startup_program, loss,
                            parameter_list, no_grad_set, callbacks):
 
-        # apply recompute backward pass
-        if self._dist_strategy.recompute:
-            assert auto_parallel_recompute_pass
-            auto_parallel_recompute_pass.apply_forward(
-                main_program, startup_program, parameter_list, no_grad_set,
-                self._pass_context)
-        else:
-            from paddle.fluid.backward import append_backward
-            with program_guard(main_program, startup_program):
-                params_grads = append_backward(
-                    loss,
-                    parameter_list,
-                    no_grad_set,
-                    callbacks,
-                    distop_context=self._dist_context.dist_op_context)
-            complete_backward_annotation(
-                main_program, dist_context=self._dist_context)
-
-        # apply amp forward pass
-        if self._dist_strategy.amp:
-            assert auto_parallel_amp_pass
-            auto_parallel_amp_pass.apply_backward(main_program, startup_program,
-                                                  self._pass_context)
+        with program_guard(main_program, startup_program):
+            params_grads = append_backward(
+                loss,
+                parameter_list,
+                no_grad_set,
+                callbacks,
+                distop_context=self._dist_context.dist_op_context)
+        complete_backward_annotation(
+            main_program, dist_context=self._dist_context)
 
         return params_grads
@@ -192,14 +179,14 @@ class AutoParallelizer:
             completed_main_program = serial_main_program
             self._dist_context = copy.deepcopy(dist_context)
 
-        # serial forward pass
-        self._apply_serial_forward_pass(completed_main_program,
-                                        serial_startup_program)
-
         # serial backward pass
         params_grads = self._generate_backward(
             completed_main_program, serial_startup_program, serial_loss,
             self._parameter_list, self._no_grad_set, self._callbacks)
 
+        # serial forward pass
+        self._apply_serial_pass(completed_main_program,
+                                serial_startup_program)
+
         # Logical partition
         rank = paddle.distributed.get_rank()
         partitioner = Partitioner(self._dist_context, rank)
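Note on the change above: the serial pipeline in AutoParallelizer now generates backward ops first (always through append_backward) and only then applies the enabled AMP/recompute passes through a single apply() call, replacing the old apply_forward()/apply_backward() split. The following is a minimal, framework-free sketch of that ordering; the method names mirror the diff, but the bodies are illustrative stand-ins, not Paddle internals.

# Minimal, framework-free sketch of the reordered pipeline; names mirror the
# diff above, bodies are stand-ins rather than Paddle code.

class TinyParallelizer:
    def __init__(self, amp=False, recompute=False):
        self.amp = amp
        self.recompute = recompute
        self.log = []

    def _generate_backward(self, program):
        # Backward ops are now always appended first (append_backward in Paddle).
        self.log.append("append_backward")
        return [("param", "param@GRAD")]

    def _apply_serial_pass(self, program):
        # Each enabled pass is applied once over the whole program via apply(),
        # replacing the old apply_forward()/apply_backward() split.
        if self.amp:
            self.log.append("amp_pass.apply")
        if self.recompute:
            self.log.append("recompute_pass.apply")

    def parallelize(self, program):
        params_grads = self._generate_backward(program)  # 1. backward first
        self._apply_serial_pass(program)                 # 2. passes second
        self.log.append("partition")                     # 3. logical partition
        return params_grads


if __name__ == "__main__":
    p = TinyParallelizer(amp=True, recompute=True)
    p.parallelize(program=None)
    print(p.log)
    # ['append_backward', 'amp_pass.apply', 'recompute_pass.apply', 'partition']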
python/paddle/distributed/passes/auto_parallel_sharding.py

@@ -94,7 +94,7 @@ class ShardingPass(PassBase):
     def _collective_data_parallel_groups(self, main_block):
         for op in main_block.ops:
-            if op.type in _skip_ops:
+            if not _is_forward_op(op) or op.type in _skip_ops:
                 continue
             group = _inference_data_parallel_group_for_operator(
                 self.global_rank, op, self._dist_context)
@@ -106,7 +106,7 @@ class ShardingPass(PassBase):
         if len(self.dp_groups) != 1:
             raise NotImplementedError(
                 "So far Only and Exactly one data parallel group in network are supported, but got [{}] different data parallel groups".
-                format(len(groups)))
+                format(len(self.dp_groups)))
 
     def _build_sharding_infos(self, params_grads):
@@ -193,18 +193,32 @@ class ShardingPass(PassBase):
             return
 
         # TODO (JZ-LIANG) support calculate global norm with tensor parallelism
-        is_clip_grad_by_global_norm = False
+        removed_op_type = ['elementwise_mul', 'squared_l2_norm', 'clip_by_norm']
+        removed_op_idx = set()
+        removed_tmp_var = set()
+
         for idx, op in list(enumerate(main_block.ops)):
             if not _is_gradient_clip_op(op):
                 continue
-            if op.type == 'sum':
-                is_clip_grad_by_global_norm = True
-                break
-        if not is_clip_grad_by_global_norm:
-            return
 
-        removed_op_idx = set()
-        removed_tmp_var = set()
+            if op.type in removed_op_type:
+                input_name = op.input("X")[0]
+                param_name = input_name[:input_name.find("@GRAD")]
+                if not self._is_parameter_in_local_shard(param_name):
+                    removed_op_idx.add(idx)
+                    if op.type in ['squared_l2_norm', 'clip_by_norm']:
+                        for output_name in op.output_arg_names:
+                            removed_tmp_var.add(output_name)
+
+        for idx, op in reversed(list(enumerate(main_block.ops))):
+            if not _is_gradient_clip_op(op):
+                continue
+            if idx in removed_op_idx:
+                main_block._remove_op(idx, sync=False)
+
+        for varname in removed_tmp_var:
+            main_block._remove_var(varname, sync=False)
+
         for idx, op in list(enumerate(main_block.ops)):
             if not _is_gradient_clip_op(op):
                 continue
@@ -218,7 +232,7 @@ class ShardingPass(PassBase):
                 sum_op_output = op.desc.output_arg_names()[0]
                 for i, sharding_info in enumerate(self.sharding_infos):
                     new_op = main_block._insert_op(
-                        idx + i,
+                        idx + i + 1,
                         type='c_allreduce_sum',
                         inputs={'X': [sum_op_output]},
                         outputs={'Out': [sum_op_output]},
@@ -235,21 +249,6 @@ class ShardingPass(PassBase):
                         new_op, dist_attr.process_mesh, dist_attr.dims_mapping,
                         self._dist_context)
                 break
 
-            for input_name in op.input_arg_names:
-                param_name = input_name[:input_name.find("@GRAD")]
-                if not self._is_parameter_in_local_shard(param_name):
-                    removed_op_idx.add(idx)
-                    for output_name in op.output_arg_names:
-                        removed_tmp_var.add(output_name)
-
-        for idx, op in reversed(list(enumerate(main_block.ops))):
-            if not _is_gradient_clip_op(op):
-                continue
-            if idx in removed_op_idx:
-                main_block._remove_op(idx, sync=False)
-
-        for varname in removed_tmp_var:
-            main_block._remove_var(varname, sync=False)
-
         main_block._sync_with_cpp()
@@ -424,12 +423,15 @@ class ShardingPass(PassBase):
                     startup_block._remove_op(idx, sync=False)
                     continue
 
-                if op.type != "c_broadcast" and output_name in not_used_param_nane:
+                if op.type != "c_broadcast" and output_name in param_usage and sharding_info.get_var_rank(
+                        output_name) != sharding_info.local_rank:
                     startup_block._remove_op(idx, sync=False)
 
-        for varname in not_used_param_nane:
-            main_block._remove_var(varname, sync=False)
-            startup_block._remove_var(varname, sync=False)
+        for param_name in param_usage:
+            if sharding_info.get_var_rank(param_name) != sharding_info.local_rank:
+                main_block._remove_var(param_name, sync=False)
+                startup_block._remove_var(param_name, sync=False)
 
         main_block._sync_with_cpp()
         startup_block._sync_with_cpp()
@@ -594,6 +596,10 @@ def _is_param_grad_allreduce_op(op, block, dp_ring_ids):
     return block.var(base_name).is_parameter
 
 
+def _is_forward_op(op):
+    return op.attr("op_role") == 0
+
+
 def _inference_data_parallel_group_for_operator(rank_id, op, dist_context):
 
     dp_group = None
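Note on _shard_gradient_clip above: the fix first collects the indices of clip-related ops (elementwise_mul, squared_l2_norm, clip_by_norm) whose parameter is not in the local shard during one forward scan, then deletes them while walking the indices in reverse so the remaining indices stay valid; the c_allreduce_sum for the global norm is then inserted at idx + i + 1, i.e. after the sum op. Below is a minimal, framework-free sketch of that collect-then-remove-in-reverse pattern; Op, Block, and is_local_shard are illustrative stand-ins, not the Paddle block API.

# Framework-free sketch of the collect-then-remove pattern; Op, Block and
# is_local_shard are stand-ins, not Paddle's block/op objects.
from dataclasses import dataclass, field


@dataclass
class Op:
    type: str
    param: str


@dataclass
class Block:
    ops: list = field(default_factory=list)

    def _remove_op(self, idx):
        del self.ops[idx]


def prune_nonlocal_clip_ops(block, is_local_shard):
    removed_op_type = ['elementwise_mul', 'squared_l2_norm', 'clip_by_norm']
    removed_op_idx = set()

    # Pass 1: record indices of clip-related ops whose parameter is not owned
    # by this rank's shard.
    for idx, op in enumerate(block.ops):
        if op.type in removed_op_type and not is_local_shard(op.param):
            removed_op_idx.add(idx)

    # Pass 2: delete in reverse index order so the remaining indices stay valid.
    for idx in sorted(removed_op_idx, reverse=True):
        block._remove_op(idx)


block = Block([Op('squared_l2_norm', 'w0'), Op('squared_l2_norm', 'w1'),
               Op('sum', '-'), Op('clip_by_norm', 'w0')])
prune_nonlocal_clip_ops(block, is_local_shard=lambda p: p == 'w0')
print([op.type for op in block.ops])  # ['squared_l2_norm', 'sum', 'clip_by_norm']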
python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py

@@ -178,13 +178,13 @@ class AutoPallelPassTestBase(DistPassTestBase):
             preds = model(tokens, position_ids, attention_mask)
             criterion = GPTPretrainingCriterion()
             loss = criterion(preds, labels, loss_mask)
-
+            clip = paddle.nn.ClipGradByNorm(clip_norm=1.0)
             optimizer = paddle.fluid.optimizer.AdamOptimizer(
                 learning_rate=0.00001,
                 beta1=0.9,
                 beta2=0.999,
                 epsilon=1e-08,
-                grad_clip=None)
+                grad_clip=clip)
             optimizer = fleet.distributed_optimizer(optimizer)
             startup_program = paddle.static.default_startup_program()
             _, _, dist_startup_prog, dist_main_prog = optimizer.minimize(
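Note on the test-base change above: the GPT pass tests now exercise gradient clipping by handing a paddle.nn.ClipGradByNorm instance to the optimizer via grad_clip instead of passing None. Below is a standalone sketch of that wiring using the public dygraph Adam API; the test itself uses the legacy static-graph AdamOptimizer, where the grad_clip argument plays the same role.

# Standalone sketch of attaching a per-tensor gradient clip to an optimizer,
# mirroring the grad_clip=clip change above (dygraph API, not the test harness).
import paddle

linear = paddle.nn.Linear(4, 2)
clip = paddle.nn.ClipGradByNorm(clip_norm=1.0)  # clip each gradient tensor to L2 norm <= 1.0
opt = paddle.optimizer.Adam(learning_rate=1e-5,
                            parameters=linear.parameters(),
                            grad_clip=clip)

x = paddle.rand([8, 4])
loss = linear(x).mean()
loss.backward()
opt.step()        # gradients are clipped before the parameter update
opt.clear_grad()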
python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_sharding_pass.py

@@ -46,7 +46,7 @@ class TestShardingPass(AutoPallelPassTestBase):
         dist_strategy.sharding = True
         dist_strategy.sharding_configs = {
             "sharding_degree": 2,
-            "stage": 3,
+            "stage": 2,
         }
         fleet.init(is_collective=True, strategy=dist_strategy)
python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py

@@ -157,9 +157,6 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     complete_train_program = auto.complete_annotation(train_program,
                                                       dist_context)
 
-    parallelizer._apply_serial_forward_pass(complete_train_program,
-                                            startup_program)
-
     params_grads = parallelizer._generate_backward(
         complete_train_program,
         startup_program,
python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py

@@ -478,8 +478,7 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     # auto completion
     complete_train_program = auto.complete_annotation(train_program,
                                                       dist_context)
-    parallelizer._apply_serial_forward_pass(complete_train_program,
-                                            startup_program)
+
     params_grads = parallelizer._generate_backward(
         complete_train_program,
         startup_program,
python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py

@@ -884,10 +884,6 @@ class TestGPTPartitioner(unittest.TestCase):
         complete_train_program = auto.complete_annotation(train_program,
                                                           dist_context)
 
-        # serial forward pass
-        parallelizer._apply_serial_forward_pass(complete_train_program,
-                                                startup_program)
-
         # serial backward pass
         params_grads = parallelizer._generate_backward(
             complete_train_program,
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py

@@ -155,9 +155,6 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     complete_train_program = auto.complete_annotation(train_program,
                                                       dist_context)
 
-    parallelizer._apply_serial_forward_pass(complete_train_program,
-                                            startup_program)
-
    params_grads = parallelizer._generate_backward(
         complete_train_program,
         startup_program,
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py

@@ -119,9 +119,6 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     complete_train_program = auto.complete_annotation(train_program,
                                                       dist_context)
 
-    parallelizer._apply_serial_forward_pass(complete_train_program,
-                                            startup_program)
-
     params_grads = parallelizer._generate_backward(
         complete_train_program,
         startup_program,
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py

@@ -134,8 +134,6 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     # serial forward & backward completion
     complete_train_program = auto.complete_annotation(train_program,
                                                       dist_context)
-    parallelizer._apply_serial_forward_pass(complete_train_program,
-                                            startup_program)
     params_grads = parallelizer._generate_backward(
         complete_train_program,