Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
3576e49c
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3576e49c
编写于
9月 09, 2022
作者:
Z
zhaoyingli
提交者:
GitHub
9月 09, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[AutoParallel] adapt gradient merge pass (#45915)
* adapt gradient merge * fix op_role * fix strategy
上级
369a235d
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
75 addition
and
34 deletion
+75
-34
python/paddle/distributed/auto_parallel/engine.py
python/paddle/distributed/auto_parallel/engine.py
+18
-7
python/paddle/distributed/auto_parallel/parallelizer.py
python/paddle/distributed/auto_parallel/parallelizer.py
+13
-3
python/paddle/distributed/auto_parallel/parallelizer_v2.py
python/paddle/distributed/auto_parallel/parallelizer_v2.py
+1
-1
python/paddle/distributed/passes/auto_parallel_fp16.py
python/paddle/distributed/passes/auto_parallel_fp16.py
+5
-5
python/paddle/distributed/passes/auto_parallel_grad_clip.py
python/paddle/distributed/passes/auto_parallel_grad_clip.py
+5
-1
python/paddle/distributed/passes/auto_parallel_gradient_merge.py
...paddle/distributed/passes/auto_parallel_gradient_merge.py
+0
-9
python/paddle/distributed/passes/auto_parallel_sharding.py
python/paddle/distributed/passes/auto_parallel_sharding.py
+31
-7
python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py
...ttests/distributed_passes/auto_parallel_pass_test_base.py
+2
-1
未找到文件。
python/paddle/distributed/auto_parallel/engine.py
浏览文件 @
3576e49c
...
@@ -60,14 +60,10 @@ class Engine:
...
@@ -60,14 +60,10 @@ class Engine:
strategy
=
None
,
strategy
=
None
,
user_tuning_config
=
None
):
user_tuning_config
=
None
):
self
.
model
=
model
self
.
model
=
model
self
.
strategy
=
strategy
or
fleet
.
DistributedStrategy
()
self
.
inputs_spec
=
self
.
_validate_spec
(
inputs_spec
)
self
.
inputs_spec
=
self
.
_validate_spec
(
inputs_spec
)
self
.
labels_spec
=
self
.
_validate_spec
(
labels_spec
)
self
.
labels_spec
=
self
.
_validate_spec
(
labels_spec
)
self
.
cluster
=
cluster
self
.
cluster
=
cluster
or
get_default_cluster
()
if
self
.
cluster
is
None
:
self
.
cluster
=
get_default_cluster
()
self
.
strategy
=
strategy
if
self
.
strategy
is
None
:
self
.
strategy
=
fleet
.
DistributedStrategy
()
self
.
_user_tuning_config
=
user_tuning_config
self
.
_user_tuning_config
=
user_tuning_config
self
.
_executor
=
None
self
.
_executor
=
None
...
@@ -433,7 +429,7 @@ class Engine:
...
@@ -433,7 +429,7 @@ class Engine:
break
break
train_logs
[
"step: {:d} "
]
=
step
train_logs
[
"step: {:d} "
]
=
step
if
lr_scheduler
is
not
None
:
if
lr_scheduler
is
not
None
and
step
%
self
.
k_steps
==
0
:
lr_scheduler
.
step
()
lr_scheduler
.
step
()
try
:
try
:
train_logs
[
"lr: {:5e} "
]
=
self
.
_lr_optimizer
.
get_lr
()
train_logs
[
"lr: {:5e} "
]
=
self
.
_lr_optimizer
.
get_lr
()
...
@@ -551,6 +547,12 @@ class Engine:
...
@@ -551,6 +547,12 @@ class Engine:
epochs
=
1
,
epochs
=
1
,
steps_per_epoch
=
None
,
steps_per_epoch
=
None
,
collate_fn
=
None
):
collate_fn
=
None
):
if
self
.
strategy
.
gradient_merge
and
batch_size
is
not
None
:
assert
batch_size
%
self
.
k_steps
==
0
,
\
"Requires batch_size:[{}] to be divisible by k_steps:[{}]."
.
format
(
batch_size
,
self
.
k_steps
)
batch_size
//=
self
.
k_steps
dist_main_prog
=
self
.
_dist_main_progs
[
self
.
mode
][
self
.
_cur_rank
]
dist_main_prog
=
self
.
_dist_main_progs
[
self
.
mode
][
self
.
_cur_rank
]
dist_startup_prog
=
self
.
_dist_startup_progs
[
self
.
mode
][
self
.
_cur_rank
]
dist_startup_prog
=
self
.
_dist_startup_progs
[
self
.
mode
][
self
.
_cur_rank
]
dist_context
=
self
.
_dist_contexts
[
self
.
mode
]
dist_context
=
self
.
_dist_contexts
[
self
.
mode
]
...
@@ -612,6 +614,9 @@ class Engine:
...
@@ -612,6 +614,9 @@ class Engine:
def
_validate_spec
(
self
,
specs
):
def
_validate_spec
(
self
,
specs
):
specs
=
to_list
(
specs
)
specs
=
to_list
(
specs
)
self
.
k_steps
=
1
if
self
.
strategy
.
gradient_merge
:
self
.
k_steps
=
self
.
strategy
.
gradient_merge_configs
[
'k_steps'
]
if
specs
is
not
None
:
if
specs
is
not
None
:
for
i
,
spec
in
enumerate
(
specs
):
for
i
,
spec
in
enumerate
(
specs
):
assert
isinstance
(
spec
,
InputSpec
)
assert
isinstance
(
spec
,
InputSpec
)
...
@@ -619,6 +624,12 @@ class Engine:
...
@@ -619,6 +624,12 @@ class Engine:
raise
ValueError
(
raise
ValueError
(
"Requires Input[{}].name != None, but receive `None` with {}."
"Requires Input[{}].name != None, but receive `None` with {}."
.
format
(
i
,
spec
))
.
format
(
i
,
spec
))
if
self
.
k_steps
>
1
:
shape
=
list
(
spec
.
shape
)
assert
shape
[
0
]
%
self
.
k_steps
==
0
,
\
"Requires batch_size[{}] to be divisible by k_steps[{}]."
.
format
(
spec
.
shape
[
0
],
self
.
k_steps
)
shape
[
0
]
//=
self
.
k_steps
spec
.
shape
=
shape
return
specs
return
specs
def
_is_local_var
(
self
,
var
):
def
_is_local_var
(
self
,
var
):
...
...
python/paddle/distributed/auto_parallel/parallelizer.py
浏览文件 @
3576e49c
...
@@ -84,7 +84,7 @@ class AutoParallelizer:
...
@@ -84,7 +84,7 @@ class AutoParallelizer:
self
.
_need_rank_mapping
=
os
.
getenv
(
"PADDLE_NEED_RANK_MAPPING"
)
self
.
_need_rank_mapping
=
os
.
getenv
(
"PADDLE_NEED_RANK_MAPPING"
)
self
.
_need_rank_mapping
=
True
if
self
.
_need_rank_mapping
and
\
self
.
_need_rank_mapping
=
True
if
self
.
_need_rank_mapping
and
\
self
.
_need_rank_mapping
.
lower
()
==
'true'
else
False
self
.
_need_rank_mapping
.
lower
()
==
'true'
else
False
self
.
_pass_context
=
None
#
self._pass_context = None
def
_remove_distributed_attrs
(
self
,
main_program
):
def
_remove_distributed_attrs
(
self
,
main_program
):
suffix
=
core
.
kAutoParallelSuffix
()
suffix
=
core
.
kAutoParallelSuffix
()
...
@@ -143,10 +143,11 @@ class AutoParallelizer:
...
@@ -143,10 +143,11 @@ class AutoParallelizer:
def
_apply_optimize
(
self
,
main_program
,
startup_program
,
params_grads
):
def
_apply_optimize
(
self
,
main_program
,
startup_program
,
params_grads
):
optimizer
=
copy
.
deepcopy
(
self
.
_optimizer
)
with
program_guard
(
main_program
,
startup_program
):
with
program_guard
(
main_program
,
startup_program
):
optimize_ops
=
copy
.
deepcopy
(
optimize_ops
=
optimizer
.
apply_gradients
(
params_grads
)
self
.
_optimizer
).
apply_gradients
(
params_grads
)
self
.
_dist_context
.
_lr_optimizer
=
optimizer
# update completion
# update completion
self
.
_completer
=
Completer
(
self
.
_dist_context
)
self
.
_completer
=
Completer
(
self
.
_dist_context
)
self
.
_completer
.
complete_update_annotation
(
main_program
)
self
.
_completer
.
complete_update_annotation
(
main_program
)
...
@@ -165,6 +166,15 @@ class AutoParallelizer:
...
@@ -165,6 +166,15 @@ class AutoParallelizer:
config
)
config
)
auto_parallel_sharding_pass
.
apply
([
main_program
],
[
startup_program
],
auto_parallel_sharding_pass
.
apply
([
main_program
],
[
startup_program
],
self
.
_pass_context
)
self
.
_pass_context
)
params_grads
=
self
.
_pass_context
.
get_attr
(
"params_grads"
)
config
=
copy
.
deepcopy
(
self
.
_dist_strategy
.
sharding_configs
)
config
[
"dist_context"
]
=
self
.
_dist_context
config
[
"params_grads"
]
=
params_grads
config
[
"rank_id"
]
=
rank
auto_parallel_clip_pass
=
new_pass
(
"auto_parallel_grad_clip"
,
config
)
auto_parallel_clip_pass
.
apply
([
main_program
],
[
startup_program
],
self
.
_pass_context
)
if
self
.
_dist_strategy
.
gradient_merge
:
if
self
.
_dist_strategy
.
gradient_merge
:
config
=
copy
.
deepcopy
(
self
.
_dist_strategy
.
gradient_merge_configs
)
config
=
copy
.
deepcopy
(
self
.
_dist_strategy
.
gradient_merge_configs
)
...
...
python/paddle/distributed/auto_parallel/parallelizer_v2.py
浏览文件 @
3576e49c
...
@@ -230,9 +230,9 @@ class Parallelizer:
...
@@ -230,9 +230,9 @@ class Parallelizer:
config
)
config
)
auto_parallel_sharding_pass
.
apply
([
main_program
],
[
startup_program
],
auto_parallel_sharding_pass
.
apply
([
main_program
],
[
startup_program
],
self
.
_pass_context
)
self
.
_pass_context
)
params_grads
=
self
.
_pass_context
.
get_attr
(
"params_grads"
)
# GradClip is train-only optimization
# GradClip is train-only optimization
if
self
.
_mode
==
"train"
:
if
self
.
_mode
==
"train"
:
config
=
copy
.
deepcopy
(
self
.
_strategy
.
sharding_configs
)
config
=
copy
.
deepcopy
(
self
.
_strategy
.
sharding_configs
)
config
[
"dist_context"
]
=
self
.
_dist_context
config
[
"dist_context"
]
=
self
.
_dist_context
...
...
python/paddle/distributed/passes/auto_parallel_fp16.py
浏览文件 @
3576e49c
...
@@ -442,7 +442,7 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context):
...
@@ -442,7 +442,7 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context):
inputs
=
{
'X'
:
grads
,
'Scale'
:
loss_scaling
}
inputs
=
{
'X'
:
grads
,
'Scale'
:
loss_scaling
}
outputs
=
{
'Out'
:
grads
,
'FoundInfinite'
:
found_inf
}
outputs
=
{
'Out'
:
grads
,
'FoundInfinite'
:
found_inf
}
attrs
=
{
'op_role'
:
OpRole
.
Backward
}
attrs
=
{
'op_role'
:
OpRole
.
Optimize
}
new_op
=
main_block
.
append_op
(
type
=
'check_finite_and_unscale'
,
new_op
=
main_block
.
append_op
(
type
=
'check_finite_and_unscale'
,
inputs
=
inputs
,
inputs
=
inputs
,
outputs
=
outputs
,
outputs
=
outputs
,
...
@@ -575,18 +575,18 @@ class FP16Pass(AMPPass):
...
@@ -575,18 +575,18 @@ class FP16Pass(AMPPass):
)
or
self
.
get_attr
(
"init_loss_scaling"
)
!=
1.0
:
)
or
self
.
get_attr
(
"init_loss_scaling"
)
!=
1.0
:
found_infs
=
[]
found_infs
=
[]
if
fp32_grads
:
if
fp32_grads
:
with
main_program
.
_
backward_role_guard
(
):
with
main_program
.
_
optimized_guard
([]
):
_
,
found_inf_fp32
=
_check_and_update_gradient
(
_
,
found_inf_fp32
=
_check_and_update_gradient
(
fp32_grads
,
self
.
_loss_scaling
,
"@fp32"
,
fp32_grads
,
self
.
_loss_scaling
,
"@fp32"
,
self
.
dist_context
)
self
.
dist_context
)
found_infs
.
append
(
found_inf_fp32
)
found_infs
.
append
(
found_inf_fp32
)
if
fp16_grads
:
if
fp16_grads
:
with
main_program
.
_
backward_role_guard
(
):
with
main_program
.
_
optimized_guard
([]
):
_
,
found_inf_fp16
=
_check_and_update_gradient
(
_
,
found_inf_fp16
=
_check_and_update_gradient
(
fp16_grads
,
self
.
_loss_scaling
,
"@fp16"
,
fp16_grads
,
self
.
_loss_scaling
,
"@fp16"
,
self
.
dist_context
)
self
.
dist_context
)
found_infs
.
append
(
found_inf_fp16
)
found_infs
.
append
(
found_inf_fp16
)
with
main_program
.
_
backward_role_guard
(
):
with
main_program
.
_
optimized_guard
([]
):
block
=
main_program
.
global_block
()
block
=
main_program
.
global_block
()
all_infs
=
paddle
.
fluid
.
layers
.
concat
(
found_infs
)
all_infs
=
paddle
.
fluid
.
layers
.
concat
(
found_infs
)
...
@@ -608,7 +608,7 @@ class FP16Pass(AMPPass):
...
@@ -608,7 +608,7 @@ class FP16Pass(AMPPass):
block
,
self
.
dist_context
)
block
,
self
.
dist_context
)
if
self
.
get_attr
(
"use_dynamic_loss_scaling"
):
if
self
.
get_attr
(
"use_dynamic_loss_scaling"
):
with
main_program
.
_
backward_role_guard
(
):
with
main_program
.
_
optimized_guard
([]
):
if
fp32_grads
:
if
fp32_grads
:
self
.
_update_loss_scaling
(
fp32_grads
,
found_inf
)
self
.
_update_loss_scaling
(
fp32_grads
,
found_inf
)
if
fp16_grads
:
if
fp16_grads
:
...
...
python/paddle/distributed/passes/auto_parallel_grad_clip.py
浏览文件 @
3576e49c
...
@@ -207,6 +207,7 @@ class ClipGradByGloblNormPass(PassBase):
...
@@ -207,6 +207,7 @@ class ClipGradByGloblNormPass(PassBase):
super
(
ClipGradByGloblNormPass
,
self
).
__init__
()
super
(
ClipGradByGloblNormPass
,
self
).
__init__
()
self
.
set_attr
(
"rank_id"
,
None
)
self
.
set_attr
(
"rank_id"
,
None
)
self
.
set_attr
(
"dist_context"
,
None
)
self
.
set_attr
(
"dist_context"
,
None
)
self
.
set_attr
(
"params_grads"
,
None
)
def
_check_self
(
self
):
def
_check_self
(
self
):
if
self
.
get_attr
(
"dist_context"
)
is
None
:
if
self
.
get_attr
(
"dist_context"
)
is
None
:
...
@@ -214,6 +215,8 @@ class ClipGradByGloblNormPass(PassBase):
...
@@ -214,6 +215,8 @@ class ClipGradByGloblNormPass(PassBase):
dist_context
=
self
.
get_attr
(
"dist_context"
)
dist_context
=
self
.
get_attr
(
"dist_context"
)
if
dist_context
.
_lr_optimizer
.
_grad_clip
is
None
:
if
dist_context
.
_lr_optimizer
.
_grad_clip
is
None
:
return
False
return
False
if
self
.
get_attr
(
"params_grads"
)
is
None
:
return
False
return
True
return
True
def
_check_conflict
(
self
,
other_pass
):
def
_check_conflict
(
self
,
other_pass
):
...
@@ -223,7 +226,8 @@ class ClipGradByGloblNormPass(PassBase):
...
@@ -223,7 +226,8 @@ class ClipGradByGloblNormPass(PassBase):
dist_context
=
self
.
get_attr
(
"dist_context"
,
None
)
dist_context
=
self
.
get_attr
(
"dist_context"
,
None
)
rank_id
=
self
.
get_attr
(
"rank_id"
,
None
)
rank_id
=
self
.
get_attr
(
"rank_id"
,
None
)
block
=
main_program
.
global_block
()
block
=
main_program
.
global_block
()
dist_params_grads
=
_get_params_grads
(
block
)
dist_params_grads
=
self
.
get_attr
(
"params_grads"
,
None
)
# dist_params_grads = _get_params_grads(block)
self
.
clip_helper
=
ClipHelper
(
dist_params_grads
,
rank_id
,
block
,
self
.
clip_helper
=
ClipHelper
(
dist_params_grads
,
rank_id
,
block
,
dist_context
)
dist_context
)
...
...
python/paddle/distributed/passes/auto_parallel_gradient_merge.py
浏览文件 @
3576e49c
...
@@ -55,13 +55,6 @@ def _remove_and_get_optimizer_op(main_program, dist_context):
...
@@ -55,13 +55,6 @@ def _remove_and_get_optimizer_op(main_program, dist_context):
return
optimize_ops_desc
return
optimize_ops_desc
def
_remove_op_role_var
(
param
,
grad
):
op_maker
=
core
.
op_proto_and_checker_maker
op
=
grad
.
op
if
op
and
op
.
has_attr
(
op_maker
.
kOpRoleVarAttrName
()):
op
.
_remove_attr
(
op_maker
.
kOpRoleVarAttrName
())
def
_get_gm_cond_var
(
main_program
,
k_steps
,
dist_context
):
def
_get_gm_cond_var
(
main_program
,
k_steps
,
dist_context
):
main_block
=
main_program
.
global_block
()
main_block
=
main_program
.
global_block
()
# Add const var
# Add const var
...
@@ -147,8 +140,6 @@ def _append_gradient_merge_backward_op(
...
@@ -147,8 +140,6 @@ def _append_gradient_merge_backward_op(
param
.
type
!=
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
param
.
type
!=
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
),
"SELECTED_ROWS is not supported in GradientMergeOptimizer for now"
),
"SELECTED_ROWS is not supported in GradientMergeOptimizer for now"
_remove_op_role_var
(
param
,
grad
)
# {grad.name: gradient_merge_var.name} to rename opt inputs
# {grad.name: gradient_merge_var.name} to rename opt inputs
grad_to_gradient_merge
=
{}
grad_to_gradient_merge
=
{}
# {param: gradient_merge_var} to insert scale op and fill_constant op
# {param: gradient_merge_var} to insert scale op and fill_constant op
...
...
python/paddle/distributed/passes/auto_parallel_sharding.py
浏览文件 @
3576e49c
...
@@ -59,6 +59,7 @@ class ShardingPass(PassBase):
...
@@ -59,6 +59,7 @@ class ShardingPass(PassBase):
self
.
varname_to_sharding_info
=
{}
self
.
varname_to_sharding_info
=
{}
self
.
partial_sharding
=
False
self
.
partial_sharding
=
False
self
.
outer_dp_group
=
None
self
.
outer_dp_group
=
None
self
.
shared_params_grads
=
[]
def
_check_self
(
self
):
def
_check_self
(
self
):
if
self
.
get_attr
(
"dist_context"
)
is
None
:
if
self
.
get_attr
(
"dist_context"
)
is
None
:
...
@@ -94,6 +95,8 @@ class ShardingPass(PassBase):
...
@@ -94,6 +95,8 @@ class ShardingPass(PassBase):
self
.
_shard_gradient_synchronization
(
main_block
)
self
.
_shard_gradient_synchronization
(
main_block
)
self
.
_shard_parameter
(
main_block
,
startup_block
)
self
.
_shard_parameter
(
main_block
,
startup_block
)
context
.
set_attr
(
"params_grads"
,
self
.
shared_params_grads
)
def
_build_sharding_groups
(
self
,
main_block
,
params_grads
):
def
_build_sharding_groups
(
self
,
main_block
,
params_grads
):
self
.
_collective_data_parallel_groups
(
main_block
)
self
.
_collective_data_parallel_groups
(
main_block
)
self
.
_build_sharding_infos
(
params_grads
)
self
.
_build_sharding_infos
(
params_grads
)
...
@@ -148,13 +151,10 @@ class ShardingPass(PassBase):
...
@@ -148,13 +151,10 @@ class ShardingPass(PassBase):
self
.
_dist_context
.
_sharding_group
=
sharding_group
self
.
_dist_context
.
_sharding_group
=
sharding_group
# TODO(JZ-LIANG) when support multiple dp groups in future, should group param and bind them to corresponding dp group
# TODO(JZ-LIANG) when support multiple dp groups in future, should group param and bind them to corresponding dp group
params_in_group
=
[
p
for
p
,
g
in
params_grads
]
assert
len
(
params_in_group
)
==
len
(
set
(
params_in_group
)),
"found duplicated param in params_grads"
sharding_info
=
ShardingInfo
(
sharding_group
,
self
.
global_rank
,
sharding_info
=
ShardingInfo
(
sharding_group
,
self
.
global_rank
,
params_
in_group
)
params_
grads
)
self
.
sharding_infos
.
append
(
sharding_info
)
self
.
sharding_infos
.
append
(
sharding_info
)
for
param
in
params_in_group
:
for
param
in
sharding_info
.
params
:
self
.
varname_to_sharding_info
[
param
.
name
]
=
sharding_info
self
.
varname_to_sharding_info
[
param
.
name
]
=
sharding_info
def
_shard_optimizer
(
self
,
main_block
,
startup_block
,
params_grads
,
def
_shard_optimizer
(
self
,
main_block
,
startup_block
,
params_grads
,
...
@@ -201,6 +201,7 @@ class ShardingPass(PassBase):
...
@@ -201,6 +201,7 @@ class ShardingPass(PassBase):
op
.
desc
.
set_output
(
'Out'
,
reversed_x
)
op
.
desc
.
set_output
(
'Out'
,
reversed_x
)
else
:
else
:
if
op
.
type
==
"check_finite_and_unscale"
:
if
op
.
type
==
"check_finite_and_unscale"
:
op_role
=
op
.
attr
(
'op_role'
)
out_name
=
op
.
output_arg_names
[
0
]
out_name
=
op
.
output_arg_names
[
0
]
out_var
=
main_block
.
vars
[
out_name
]
out_var
=
main_block
.
vars
[
out_name
]
main_block
.
_remove_op
(
idx
,
sync
=
False
)
main_block
.
_remove_op
(
idx
,
sync
=
False
)
...
@@ -212,6 +213,7 @@ class ShardingPass(PassBase):
...
@@ -212,6 +213,7 @@ class ShardingPass(PassBase):
"shape"
:
out_var
.
shape
,
"shape"
:
out_var
.
shape
,
"dtype"
:
out_var
.
dtype
,
"dtype"
:
out_var
.
dtype
,
"value"
:
0
,
"value"
:
0
,
OP_ROLE_KEY
:
op_role
,
})
})
else
:
else
:
main_block
.
_remove_op
(
idx
,
sync
=
False
)
main_block
.
_remove_op
(
idx
,
sync
=
False
)
...
@@ -313,6 +315,9 @@ class ShardingPass(PassBase):
...
@@ -313,6 +315,9 @@ class ShardingPass(PassBase):
if
varname
!=
param_name
if
varname
!=
param_name
])
])
main_block
.
_remove_op
(
idx
,
sync
=
False
)
main_block
.
_remove_op
(
idx
,
sync
=
False
)
else
:
self
.
shared_params_grads
.
append
(
self
.
_get_param_grad
(
param_name
))
for
idx
,
op
in
reversed
(
list
(
enumerate
(
startup_block
.
ops
))):
for
idx
,
op
in
reversed
(
list
(
enumerate
(
startup_block
.
ops
))):
if
len
(
op
.
output_arg_names
)
==
1
and
op
.
output_arg_names
[
if
len
(
op
.
output_arg_names
)
==
1
and
op
.
output_arg_names
[
...
@@ -365,6 +370,13 @@ class ShardingPass(PassBase):
...
@@ -365,6 +370,13 @@ class ShardingPass(PassBase):
sharding_info
=
self
.
varname_to_sharding_info
[
param_name
]
sharding_info
=
self
.
varname_to_sharding_info
[
param_name
]
return
sharding_info
.
is_in_local_shard
(
param_name
)
return
sharding_info
.
is_in_local_shard
(
param_name
)
def
_get_param_grad
(
self
,
param_name
):
assert
param_name
in
self
.
varname_to_sharding_info
sharding_info
=
self
.
varname_to_sharding_info
[
param_name
]
p_g
=
sharding_info
.
get_param_grad
(
param_name
)
assert
p_g
is
not
None
return
p_g
def
_shard_gradient_synchronization
(
self
,
main_block
):
def
_shard_gradient_synchronization
(
self
,
main_block
):
if
self
.
stage
<
2
:
if
self
.
stage
<
2
:
...
@@ -705,9 +717,13 @@ def shard_parameters(params, group_size):
...
@@ -705,9 +717,13 @@ def shard_parameters(params, group_size):
class
ShardingInfo
(
object
):
class
ShardingInfo
(
object
):
def
__init__
(
self
,
group
,
rank
,
params
):
def
__init__
(
self
,
group
,
rank
,
params
_grads
):
self
.
group
=
group
self
.
group
=
group
self
.
params
=
params
self
.
params_grads
=
dict
([(
p
.
name
,
(
p
,
g
))
for
p
,
g
in
params_grads
])
assert
len
(
self
.
params_grads
)
==
len
(
set
(
self
.
params_grads
)),
"found duplicated param in params_grads"
self
.
params
=
[
p
for
p
,
_
in
params_grads
]
self
.
param_names
=
[
p
.
name
for
p
in
self
.
params
]
self
.
param_names
=
[
p
.
name
for
p
in
self
.
params
]
self
.
group_size
=
group
.
nranks
self
.
group_size
=
group
.
nranks
self
.
global_rank
=
rank
self
.
global_rank
=
rank
...
@@ -762,3 +778,11 @@ class ShardingInfo(object):
...
@@ -762,3 +778,11 @@ class ShardingInfo(object):
if
usage
>
0
:
if
usage
>
0
:
broadcast_vars
.
add
(
param
)
broadcast_vars
.
add
(
param
)
return
broadcast_vars
,
param_usage
return
broadcast_vars
,
param_usage
def
get_param_grad
(
self
,
param_name
):
if
not
self
.
is_in_local_shard
(
param_name
):
raise
ValueError
(
"param[{}] not in current rank."
.
format
(
param_name
))
if
param_name
not
in
self
.
params_grads
:
raise
ValueError
(
'param[{}] not in params_grads'
.
format
(
param_name
))
return
self
.
params_grads
.
get
(
param_name
,
None
)
python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py
浏览文件 @
3576e49c
...
@@ -178,6 +178,7 @@ class AutoPallelPassTestBase(DistPassTestBase):
...
@@ -178,6 +178,7 @@ class AutoPallelPassTestBase(DistPassTestBase):
preds
=
model
(
tokens
,
position_ids
,
attention_mask
)
preds
=
model
(
tokens
,
position_ids
,
attention_mask
)
criterion
=
GPTPretrainingCriterion
()
criterion
=
GPTPretrainingCriterion
()
loss
=
criterion
(
preds
,
labels
,
loss_mask
)
loss
=
criterion
(
preds
,
labels
,
loss_mask
)
clip
=
paddle
.
nn
.
ClipGradByNorm
(
clip_norm
=
1.0
)
if
kwargs
.
get
(
'optimizer'
,
None
)
==
"LarsMomentum"
:
if
kwargs
.
get
(
'optimizer'
,
None
)
==
"LarsMomentum"
:
optimizer
=
paddle
.
fluid
.
optimizer
.
LarsMomentumOptimizer
(
optimizer
=
paddle
.
fluid
.
optimizer
.
LarsMomentumOptimizer
(
...
@@ -188,7 +189,7 @@ class AutoPallelPassTestBase(DistPassTestBase):
...
@@ -188,7 +189,7 @@ class AutoPallelPassTestBase(DistPassTestBase):
beta1
=
0.9
,
beta1
=
0.9
,
beta2
=
0.999
,
beta2
=
0.999
,
epsilon
=
1e-08
,
epsilon
=
1e-08
,
grad_clip
=
None
)
grad_clip
=
clip
)
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
)
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
)
startup_program
=
paddle
.
static
.
default_startup_program
()
startup_program
=
paddle
.
static
.
default_startup_program
()
_
,
_
,
dist_startup_prog
,
dist_main_prog
=
optimizer
.
minimize
(
_
,
_
,
dist_startup_prog
,
dist_main_prog
=
optimizer
.
minimize
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录