PaddlePaddle / Paddle, commit 3576e49c (unverified)
Authored Sep 09, 2022 by zhaoyingli; committed via GitHub on Sep 09, 2022
[AutoParallel] adapt gradient merge pass (#45915)
* adapt gradient merge
* fix op_role
* fix strategy

Parent commit: 369a235d
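For context, gradient merge (gradient accumulation) sums the gradients of k_steps micro-batches into persistent buffers and runs the optimizer once every k_steps. A toy NumPy sketch of that behaviour, not the Paddle pass itself (the averaging mirrors the pass's avg option; the helper name is made up):

import numpy as np

def train_with_gradient_merge(param, micro_batch_grads, k_steps, lr=0.1):
    # Accumulate micro-batch gradients and update once every k_steps.
    grad_buffer = np.zeros_like(param)
    for step, grad in enumerate(micro_batch_grads, start=1):
        grad_buffer += grad                                # merge (accumulate) gradient
        if step % k_steps == 0:                            # optimizer block fires every k steps
            param = param - lr * grad_buffer / k_steps     # avg-style update
            grad_buffer = np.zeros_like(param)
    return param

p = np.array([1.0])
grads = [np.array([0.2]), np.array([0.4]), np.array([0.6]), np.array([0.8])]
# Two updates: 1.0 -> 0.97 -> 0.90
print(train_with_gradient_merge(p, grads, k_steps=2))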
Showing 8 changed files with 75 additions and 34 deletions (+75, -34)
python/paddle/distributed/auto_parallel/engine.py  (+18, -7)
python/paddle/distributed/auto_parallel/parallelizer.py  (+13, -3)
python/paddle/distributed/auto_parallel/parallelizer_v2.py  (+1, -1)
python/paddle/distributed/passes/auto_parallel_fp16.py  (+5, -5)
python/paddle/distributed/passes/auto_parallel_grad_clip.py  (+5, -1)
python/paddle/distributed/passes/auto_parallel_gradient_merge.py  (+0, -9)
python/paddle/distributed/passes/auto_parallel_sharding.py  (+31, -7)
python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py  (+2, -1)
python/paddle/distributed/auto_parallel/engine.py

@@ -60,14 +60,10 @@ class Engine:
                  strategy=None,
                  user_tuning_config=None):
         self.model = model
+        self.strategy = strategy or fleet.DistributedStrategy()
         self.inputs_spec = self._validate_spec(inputs_spec)
         self.labels_spec = self._validate_spec(labels_spec)
-        self.cluster = cluster
-        if self.cluster is None:
-            self.cluster = get_default_cluster()
-        self.strategy = strategy
-        if self.strategy is None:
-            self.strategy = fleet.DistributedStrategy()
+        self.cluster = cluster or get_default_cluster()
         self._user_tuning_config = user_tuning_config
         self._executor = None

@@ -433,7 +429,7 @@ class Engine:
                     break
                 train_logs["step: {:d} "] = step
-                if lr_scheduler is not None:
+                if lr_scheduler is not None and step % self.k_steps == 0:
                     lr_scheduler.step()
                 try:
                     train_logs["lr: {:5e} "] = self._lr_optimizer.get_lr()

@@ -551,6 +547,12 @@ class Engine:
                           epochs=1,
                           steps_per_epoch=None,
                           collate_fn=None):
+        if self.strategy.gradient_merge and batch_size is not None:
+            assert batch_size % self.k_steps == 0, \
+                "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(batch_size, self.k_steps)
+            batch_size //= self.k_steps
+
         dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank]
         dist_startup_prog = self._dist_startup_progs[self.mode][self._cur_rank]
         dist_context = self._dist_contexts[self.mode]

@@ -612,6 +614,9 @@ class Engine:
     def _validate_spec(self, specs):
         specs = to_list(specs)
+        self.k_steps = 1
+        if self.strategy.gradient_merge:
+            self.k_steps = self.strategy.gradient_merge_configs['k_steps']
         if specs is not None:
             for i, spec in enumerate(specs):
                 assert isinstance(spec, InputSpec)

@@ -619,6 +624,12 @@ class Engine:
                     raise ValueError(
                         "Requires Input[{}].name != None, but receive `None` with {}."
                         .format(i, spec))
+                if self.k_steps > 1:
+                    shape = list(spec.shape)
+                    assert shape[0] % self.k_steps == 0, \
+                        "Requires batch_size[{}] to be divisible by k_steps[{}].".format(spec.shape[0], self.k_steps)
+                    shape[0] //= self.k_steps
+                    spec.shape = shape
         return specs

     def _is_local_var(self, var):
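The Engine changes above derive k_steps from gradient_merge_configs and divide the batch dimension of each InputSpec (and the user-facing batch_size) by it, so every micro-step consumes batch_size // k_steps samples. A minimal standalone sketch of that bookkeeping (hypothetical helper, not Paddle API):

def split_batch_for_gradient_merge(batch_size, k_steps):
    # Mirrors the divisibility check added to Engine: the global batch must be
    # evenly split across the k_steps micro-batches.
    assert k_steps >= 1
    assert batch_size % k_steps == 0, \
        "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(
            batch_size, k_steps)
    return batch_size // k_steps

# A global batch of 64 with k_steps=4 feeds 16 samples per micro-step.
assert split_batch_for_gradient_merge(64, 4) == 16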
python/paddle/distributed/auto_parallel/parallelizer.py

@@ -84,7 +84,7 @@ class AutoParallelizer:
         self._need_rank_mapping = os.getenv("PADDLE_NEED_RANK_MAPPING")
         self._need_rank_mapping = True if self._need_rank_mapping and \
             self._need_rank_mapping.lower() == 'true' else False
-        self._pass_context = None
+        # self._pass_context = None

     def _remove_distributed_attrs(self, main_program):
         suffix = core.kAutoParallelSuffix()

@@ -143,10 +143,11 @@ class AutoParallelizer:
     def _apply_optimize(self, main_program, startup_program, params_grads):
+        optimizer = copy.deepcopy(self._optimizer)
         with program_guard(main_program, startup_program):
-            optimize_ops = copy.deepcopy(
-                self._optimizer).apply_gradients(params_grads)
+            optimize_ops = optimizer.apply_gradients(params_grads)

+        self._dist_context._lr_optimizer = optimizer
         # update completion
         self._completer = Completer(self._dist_context)
         self._completer.complete_update_annotation(main_program)

@@ -165,6 +166,15 @@ class AutoParallelizer:
                                                  config)
             auto_parallel_sharding_pass.apply([main_program], [startup_program],
                                               self._pass_context)
+            params_grads = self._pass_context.get_attr("params_grads")
+
+            config = copy.deepcopy(self._dist_strategy.sharding_configs)
+            config["dist_context"] = self._dist_context
+            config["params_grads"] = params_grads
+            config["rank_id"] = rank
+            auto_parallel_clip_pass = new_pass("auto_parallel_grad_clip", config)
+            auto_parallel_clip_pass.apply([main_program], [startup_program],
+                                          self._pass_context)

         if self._dist_strategy.gradient_merge:
             config = copy.deepcopy(self._dist_strategy.gradient_merge_configs)
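The new block lets the sharding pass publish its surviving params_grads on the shared pass context, and the grad-clip pass then receives that list through its config instead of rescanning the program. A plain-Python stand-in for the hand-off (FakePassContext is illustrative, not Paddle's PassContext/new_pass API):

class FakePassContext:
    # Tiny attribute bag standing in for the pass context shared between passes.
    def __init__(self):
        self._attrs = {}

    def set_attr(self, name, value):
        self._attrs[name] = value

    def get_attr(self, name, default=None):
        return self._attrs.get(name, default)

context = FakePassContext()
# The sharding pass would publish its surviving (param, grad) pairs:
context.set_attr("params_grads", [("w", "w@GRAD"), ("b", "b@GRAD")])

# The grad-clip pass config is then built from the context, as in the diff:
clip_config = {
    "dist_context": None,                                  # placeholder
    "params_grads": context.get_attr("params_grads"),
    "rank_id": 0,
}
assert clip_config["params_grads"][0] == ("w", "w@GRAD")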
python/paddle/distributed/auto_parallel/parallelizer_v2.py

@@ -230,9 +230,9 @@ class Parallelizer:
                                                  config)
             auto_parallel_sharding_pass.apply([main_program], [startup_program],
                                               self._pass_context)
+            params_grads = self._pass_context.get_attr("params_grads")

         # GradClip is train-only optimization
         if self._mode == "train":
             config = copy.deepcopy(self._strategy.sharding_configs)
             config["dist_context"] = self._dist_context
python/paddle/distributed/passes/auto_parallel_fp16.py

@@ -442,7 +442,7 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context):
     inputs = {'X': grads, 'Scale': loss_scaling}
     outputs = {'Out': grads, 'FoundInfinite': found_inf}
-    attrs = {'op_role': OpRole.Backward}
+    attrs = {'op_role': OpRole.Optimize}
     new_op = main_block.append_op(type='check_finite_and_unscale',
                                   inputs=inputs,
                                   outputs=outputs,

@@ -575,18 +575,18 @@ class FP16Pass(AMPPass):
         ) or self.get_attr("init_loss_scaling") != 1.0:
             found_infs = []
             if fp32_grads:
-                with main_program._backward_role_guard():
+                with main_program._optimized_guard([]):
                     _, found_inf_fp32 = _check_and_update_gradient(
                         fp32_grads, self._loss_scaling, "@fp32",
                         self.dist_context)
                 found_infs.append(found_inf_fp32)
             if fp16_grads:
-                with main_program._backward_role_guard():
+                with main_program._optimized_guard([]):
                     _, found_inf_fp16 = _check_and_update_gradient(
                         fp16_grads, self._loss_scaling, "@fp16",
                         self.dist_context)
                 found_infs.append(found_inf_fp16)
-            with main_program._backward_role_guard():
+            with main_program._optimized_guard([]):
                 block = main_program.global_block()
                 all_infs = paddle.fluid.layers.concat(found_infs)

@@ -608,7 +608,7 @@ class FP16Pass(AMPPass):
                                     block, self.dist_context)

         if self.get_attr("use_dynamic_loss_scaling"):
-            with main_program._backward_role_guard():
+            with main_program._optimized_guard([]):
                 if fp32_grads:
                     self._update_loss_scaling(fp32_grads, found_inf)
                 if fp16_grads:
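Retagging these ops from the backward role to the optimizer role matters because gradient merge gates optimizer-role ops behind its every-k-steps condition, so the loss-scaling bookkeeping should run with the optimizer update rather than with every backward pass. A rough illustration with a stand-in enum (values and op list are illustrative, not Paddle's OpRole):

from enum import Enum

class FakeOpRole(Enum):
    Forward = 0
    Backward = 1
    Optimize = 2

ops = [
    ("matmul_grad", FakeOpRole.Backward),
    ("check_finite_and_unscale", FakeOpRole.Optimize),  # after this commit
    ("update_loss_scaling", FakeOpRole.Optimize),
    ("adam", FakeOpRole.Optimize),
]

# A pass that selects ops by role (as gradient merge does) now keeps the
# loss-scaling ops together with the conditionally executed optimizer block.
optimize_block = [name for name, role in ops if role is FakeOpRole.Optimize]
assert "check_finite_and_unscale" in optimize_block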
python/paddle/distributed/passes/auto_parallel_grad_clip.py

@@ -207,6 +207,7 @@ class ClipGradByGloblNormPass(PassBase):
         super(ClipGradByGloblNormPass, self).__init__()
         self.set_attr("rank_id", None)
         self.set_attr("dist_context", None)
+        self.set_attr("params_grads", None)

     def _check_self(self):
         if self.get_attr("dist_context") is None:

@@ -214,6 +215,8 @@ class ClipGradByGloblNormPass(PassBase):
         dist_context = self.get_attr("dist_context")
         if dist_context._lr_optimizer._grad_clip is None:
             return False
+        if self.get_attr("params_grads") is None:
+            return False
         return True

     def _check_conflict(self, other_pass):

@@ -223,7 +226,8 @@ class ClipGradByGloblNormPass(PassBase):
         dist_context = self.get_attr("dist_context", None)
         rank_id = self.get_attr("rank_id", None)
         block = main_program.global_block()
-        dist_params_grads = _get_params_grads(block)
+        dist_params_grads = self.get_attr("params_grads", None)
+        # dist_params_grads = _get_params_grads(block)

         self.clip_helper = ClipHelper(dist_params_grads, rank_id, block,
                                       dist_context)
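ClipGradByGloblNormPass now receives params_grads directly instead of collecting them from the block. For reference, the operation it distributes is ordinary clip-by-global-norm; a NumPy sketch of the formula:

import numpy as np

def clip_by_global_norm(grads, clip_norm):
    # Scale every gradient by clip_norm / max(global_norm, clip_norm).
    global_norm = np.sqrt(sum(float(np.sum(g * g)) for g in grads))
    scale = clip_norm / max(global_norm, clip_norm)
    return [g * scale for g in grads]

grads = [np.array([3.0, 4.0]), np.array([0.0, 12.0])]     # global norm = 13
clipped = clip_by_global_norm(grads, clip_norm=1.0)
assert abs(np.sqrt(sum(np.sum(g * g) for g in clipped)) - 1.0) < 1e-6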
python/paddle/distributed/passes/auto_parallel_gradient_merge.py

@@ -55,13 +55,6 @@ def _remove_and_get_optimizer_op(main_program, dist_context):
     return optimize_ops_desc


-def _remove_op_role_var(param, grad):
-    op_maker = core.op_proto_and_checker_maker
-    op = grad.op
-    if op and op.has_attr(op_maker.kOpRoleVarAttrName()):
-        op._remove_attr(op_maker.kOpRoleVarAttrName())
-
-
 def _get_gm_cond_var(main_program, k_steps, dist_context):
     main_block = main_program.global_block()
     # Add const var

@@ -147,8 +140,6 @@ def _append_gradient_merge_backward_op(
             param.type != core.VarDesc.VarType.SELECTED_ROWS
         ), "SELECTED_ROWS is not supported in GradientMergeOptimizer for now"

-        _remove_op_role_var(param, grad)
-
     # {grad.name: gradient_merge_var.name} to rename opt inputs
     # {param: gradient_merge_var} to insert scale op and fill_constant op
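With op_role handling left to the surrounding passes, the _remove_op_role_var helper is no longer needed here. For context, _get_gm_cond_var (unchanged in this commit) builds the every-k-steps condition that gates the merged optimizer ops; a rough Python equivalent of that control flow (not the pass itself):

class GradientMergeCounter:
    # Stand-in for the persistent step counter and cond var the pass creates.
    def __init__(self, k_steps):
        self.k_steps = k_steps
        self.step = 0

    def tick(self):
        self.step += 1
        return self.step % self.k_steps == 0   # True -> run the optimizer block

counter = GradientMergeCounter(k_steps=4)
fire_pattern = [counter.tick() for _ in range(8)]
assert fire_pattern == [False, False, False, True] * 2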
python/paddle/distributed/passes/auto_parallel_sharding.py

@@ -59,6 +59,7 @@ class ShardingPass(PassBase):
         self.varname_to_sharding_info = {}
         self.partial_sharding = False
         self.outer_dp_group = None
+        self.shared_params_grads = []

     def _check_self(self):
         if self.get_attr("dist_context") is None:

@@ -94,6 +95,8 @@ class ShardingPass(PassBase):
         self._shard_gradient_synchronization(main_block)
         self._shard_parameter(main_block, startup_block)

+        context.set_attr("params_grads", self.shared_params_grads)
+
     def _build_sharding_groups(self, main_block, params_grads):
         self._collective_data_parallel_groups(main_block)
         self._build_sharding_infos(params_grads)

@@ -148,13 +151,10 @@ class ShardingPass(PassBase):
             self._dist_context._sharding_group = sharding_group
             # TODO(JZ-LIANG) when support multiple dp groups in future, should group param and bind them to corresponding dp group
-            params_in_group = [p for p, g in params_grads]
-            assert len(params_in_group) == len(
-                set(params_in_group)), "found duplicated param in params_grads"
             sharding_info = ShardingInfo(sharding_group, self.global_rank,
-                                         params_in_group)
+                                         params_grads)
             self.sharding_infos.append(sharding_info)
-            for param in params_in_group:
+            for param in sharding_info.params:
                 self.varname_to_sharding_info[param.name] = sharding_info

     def _shard_optimizer(self, main_block, startup_block, params_grads,

@@ -201,6 +201,7 @@ class ShardingPass(PassBase):
                     op.desc.set_output('Out', reversed_x)
                 else:
                     if op.type == "check_finite_and_unscale":
+                        op_role = op.attr('op_role')
                         out_name = op.output_arg_names[0]
                         out_var = main_block.vars[out_name]
                         main_block._remove_op(idx, sync=False)

@@ -212,6 +213,7 @@ class ShardingPass(PassBase):
                                 "shape": out_var.shape,
                                 "dtype": out_var.dtype,
                                 "value": 0,
+                                OP_ROLE_KEY: op_role,
                             })
                     else:
                         main_block._remove_op(idx, sync=False)

@@ -313,6 +315,9 @@ class ShardingPass(PassBase):
                     if varname != param_name
                 ])
                 main_block._remove_op(idx, sync=False)
+            else:
+                self.shared_params_grads.append(
+                    self._get_param_grad(param_name))

         for idx, op in reversed(list(enumerate(startup_block.ops))):
             if len(op.output_arg_names) == 1 and op.output_arg_names[

@@ -365,6 +370,13 @@ class ShardingPass(PassBase):
         sharding_info = self.varname_to_sharding_info[param_name]
         return sharding_info.is_in_local_shard(param_name)

+    def _get_param_grad(self, param_name):
+        assert param_name in self.varname_to_sharding_info
+        sharding_info = self.varname_to_sharding_info[param_name]
+        p_g = sharding_info.get_param_grad(param_name)
+        assert p_g is not None
+        return p_g
+
     def _shard_gradient_synchronization(self, main_block):
         if self.stage < 2:

@@ -705,9 +717,13 @@ def shard_parameters(params, group_size):
 class ShardingInfo(object):

-    def __init__(self, group, rank, params):
+    def __init__(self, group, rank, params_grads):
         self.group = group
-        self.params = params
+        self.params_grads = dict([(p.name, (p, g)) for p, g in params_grads])
+        assert len(self.params_grads) == len(set(
+            self.params_grads)), "found duplicated param in params_grads"
+
+        self.params = [p for p, _ in params_grads]
         self.param_names = [p.name for p in self.params]
         self.group_size = group.nranks
         self.global_rank = rank

@@ -762,3 +778,11 @@ class ShardingInfo(object):
             if usage > 0:
                 broadcast_vars.add(param)
         return broadcast_vars, param_usage
+
+    def get_param_grad(self, param_name):
+        if not self.is_in_local_shard(param_name):
+            raise ValueError(
+                "param[{}] not in current rank.".format(param_name))
+        if param_name not in self.params_grads:
+            raise ValueError('param[{}] not in params_grads'.format(param_name))
+        return self.params_grads.get(param_name, None)
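ShardingInfo is now built from (param, grad) pairs and can return the pair owned by the local shard, which is what fills shared_params_grads. A pure-Python mock of that bookkeeping (illustrative only, not the Paddle class):

class FakeParam:
    def __init__(self, name):
        self.name = name

class MiniShardingInfo:
    def __init__(self, params_grads):
        # Keyed by parameter name, mirroring the new ShardingInfo constructor.
        self.params_grads = {p.name: (p, g) for p, g in params_grads}
        assert len(self.params_grads) == len(params_grads), \
            "found duplicated param in params_grads"
        self.params = [p for p, _ in params_grads]

    def get_param_grad(self, param_name):
        if param_name not in self.params_grads:
            raise ValueError("param[{}] not in params_grads".format(param_name))
        return self.params_grads[param_name]

w, b = FakeParam("w"), FakeParam("b")
info = MiniShardingInfo([(w, "w@GRAD"), (b, "b@GRAD")])
assert info.get_param_grad("w") == (w, "w@GRAD")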
python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py

@@ -178,6 +178,7 @@ class AutoPallelPassTestBase(DistPassTestBase):
             preds = model(tokens, position_ids, attention_mask)
             criterion = GPTPretrainingCriterion()
             loss = criterion(preds, labels, loss_mask)
+            clip = paddle.nn.ClipGradByNorm(clip_norm=1.0)
             if kwargs.get('optimizer', None) == "LarsMomentum":
                 optimizer = paddle.fluid.optimizer.LarsMomentumOptimizer(

@@ -188,7 +189,7 @@ class AutoPallelPassTestBase(DistPassTestBase):
                     beta1=0.9,
                     beta2=0.999,
                     epsilon=1e-08,
-                    grad_clip=None)
+                    grad_clip=clip)
             optimizer = fleet.distributed_optimizer(optimizer)
             startup_program = paddle.static.default_startup_program()
             _, _, dist_startup_prog, dist_main_prog = optimizer.minimize(
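The test now attaches a ClipGradByNorm to the optimizer so the new grad-clip path is exercised by the pass tests. A minimal dygraph sketch of the same wiring (Paddle 2.x imperative API, for illustration only):

import paddle

model = paddle.nn.Linear(4, 2)
clip = paddle.nn.ClipGradByNorm(clip_norm=1.0)          # same clip as the test
opt = paddle.optimizer.Adam(learning_rate=1e-3,
                            parameters=model.parameters(),
                            grad_clip=clip)

x = paddle.randn([8, 4])
loss = paddle.mean(model(x) ** 2)
loss.backward()
opt.step()        # each gradient is clipped to norm <= 1.0 before the update
opt.clear_grad()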