Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
a4eadd15
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a4eadd15
编写于
9月 16, 2021
作者:
Y
Yuang Liu
提交者:
GitHub
9月 16, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[hybrid] Fix mp multi gradient clip prob (#35713)
上级
4b683887
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
127 addition
and
23 deletion
+127
-23
python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
...ed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
+90
-19
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
...e/distributed/fleet/meta_optimizers/sharding_optimizer.py
+1
-2
python/paddle/fluid/optimizer.py
python/paddle/fluid/optimizer.py
+9
-2
python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
...uid/tests/unittests/test_fleet_sharding_meta_optimizer.py
+27
-0
未找到文件。
python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
浏览文件 @
a4eadd15
...
@@ -142,32 +142,103 @@ class GradientClipHelper(object):
...
@@ -142,32 +142,103 @@ class GradientClipHelper(object):
return
return
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
def
sync_global_norm
(
self
,
block
,
ring_ids
):
def
sync_global_norm
(
self
,
block
,
ring_ids
,
mp_rank
):
"""
"""
prune gradient_clip related ops for params that not belong to cur shard
prune gradient_clip related ops for params that not belong to cur shard
prune: square, reduce_sum, elementwise_mul
prune: square, reduce_sum, elementwise_mul
keep: sum, sqrt, elementwise_max, elementwise_div
keep: sum, sqrt, elementwise_max, elementwise_div
"""
"""
# FIXME(wangxi): mp should prune duplicated param_grads
is_clip_grad_by_global_norm
=
False
for
idx
,
op
in
list
(
enumerate
(
block
.
ops
)):
if
not
self
.
_is_gradient_clip_op
(
op
):
continue
if
op
.
type
==
'sum'
:
is_clip_grad_by_global_norm
=
True
break
if
not
is_clip_grad_by_global_norm
:
# TODO(Yuang Liu): need some extra handles when clip_grad_norm for mp
return
removed_op_idx
=
set
()
removed_tmp_var
=
set
()
for
idx
,
op
in
list
(
enumerate
(
block
.
ops
)):
if
not
self
.
_is_gradient_clip_op
(
op
):
continue
if
op
.
type
==
'sum'
:
break
for
input_name
in
op
.
input_arg_names
:
input_var
=
block
.
var
(
input_name
)
# NOTE: when mp_degree > 1, some vars will be split into each mp rank.
# However, there still some vars such as Scale, Bias are not split.
# Those not be split vars should only be counted once during grad clip
# by global norm. Those vars either doesn't have is_distributed attr
# or the is_distributed attr has been set as False.
# Therefore, we prune those duplicated vars for grad clip.
if
mp_rank
>=
1
and
(
not
(
hasattr
(
input_var
,
'is_distributed'
)
and
input_var
.
is_distributed
)):
removed_op_idx
.
add
(
idx
)
for
output_name
in
op
.
output_arg_names
:
removed_tmp_var
.
add
(
output_name
)
for
idx
,
op
in
reversed
(
list
(
enumerate
(
block
.
ops
))):
for
idx
,
op
in
reversed
(
list
(
enumerate
(
block
.
ops
))):
if
not
self
.
_is_gradient_clip_op
(
op
):
if
not
self
.
_is_gradient_clip_op
(
op
):
continue
continue
if
idx
in
removed_op_idx
:
block
.
_remove_op
(
idx
,
sync
=
False
)
if
op
.
type
==
"sum"
:
for
var_name
in
removed_tmp_var
:
sum_res
=
op
.
desc
.
output_arg_names
()[
0
]
block
.
_remove_var
(
var_name
,
sync
=
False
)
for
idx
,
op
in
list
(
enumerate
(
block
.
ops
)):
if
not
self
.
_is_gradient_clip_op
(
op
):
continue
if
op
.
type
==
'sum'
:
# If mp_rank == 0, no extra handles, just allreduce
# If mp_rank >= 1, some extra handles is needed
sum_rst_var
=
block
.
var
(
op
.
output_arg_names
[
0
])
if
mp_rank
>=
1
:
reserved_vars
=
[]
for
input_name
in
op
.
input_arg_names
:
if
input_name
not
in
removed_tmp_var
:
reserved_vars
.
append
(
input_name
)
if
len
(
reserved_vars
)
>
0
:
op
.
desc
.
set_input
(
"X"
,
reserved_vars
)
else
:
# If all input of sum op should be removed, then remove the sum op.
# And set the output's value of sum to 0.
namescope
=
op
.
attr
(
"op_namescope"
)
block
.
_remove_op
(
idx
,
sync
=
False
)
fill_constant_op
=
block
.
_insert_op_without_sync
(
idx
,
type
=
'fill_constant'
,
inputs
=
{},
outputs
=
{
'Out'
:
sum_rst_var
},
attrs
=
{
'shape'
:
sum_rst_var
.
shape
,
'dtype'
:
sum_rst_var
.
dtype
,
'value'
:
0.0
,
OP_ROLE_KEY
:
OpRole
.
Optimize
})
fill_constant_op
.
_set_attr
(
'op_namescope'
,
namescope
)
self
.
_insert_allreduce
(
block
,
ring_ids
,
idx
,
sum_rst_var
)
break
@
staticmethod
def
_insert_allreduce
(
block
,
ring_ids
,
idx
,
var
):
for
ring_id
in
ring_ids
:
for
ring_id
in
ring_ids
:
if
ring_id
==
-
1
:
continue
if
ring_id
==
-
1
:
continue
idx
=
idx
+
1
idx
=
idx
+
1
block
.
_insert_op_without_sync
(
block
.
_insert_op_without_sync
(
idx
,
idx
,
type
=
'c_allreduce_sum'
,
type
=
'c_allreduce_sum'
,
inputs
=
{
'X'
:
sum_res
},
inputs
=
{
'X'
:
var
},
outputs
=
{
'Out'
:
sum_res
},
outputs
=
{
'Out'
:
var
},
attrs
=
{
attrs
=
{
'ring_id'
:
ring_id
,
'ring_id'
:
ring_id
,
'op_namescope'
:
"/gradient_clip_model_parallelism"
,
'op_namescope'
:
"/gradient_clip_model_parallelism"
,
'use_calc_stream'
:
True
,
'use_calc_stream'
:
True
,
OP_ROLE_KEY
:
OpRole
.
Optimize
,
OP_ROLE_KEY
:
OpRole
.
Optimize
,
})
})
return
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
浏览文件 @
a4eadd15
...
@@ -435,7 +435,6 @@ class ShardingOptimizer(MetaOptimizerBase):
...
@@ -435,7 +435,6 @@ class ShardingOptimizer(MetaOptimizerBase):
main_block
=
self
.
_main_program
.
global_block
()
main_block
=
self
.
_main_program
.
global_block
()
startup_block
=
self
.
_startup_program
.
global_block
()
startup_block
=
self
.
_startup_program
.
global_block
()
# FIXME(wangxi): mp should prune duplicated param_grads when calc
# amp inf_var & clip global_norm_var
# amp inf_var & clip global_norm_var
rings
=
[
self
.
mp_ring_id
,
self
.
pp_ring_id
]
rings
=
[
self
.
mp_ring_id
,
self
.
pp_ring_id
]
...
@@ -446,7 +445,7 @@ class ShardingOptimizer(MetaOptimizerBase):
...
@@ -446,7 +445,7 @@ class ShardingOptimizer(MetaOptimizerBase):
gradientclip_helper
=
GradientClipHelper
(
None
)
gradientclip_helper
=
GradientClipHelper
(
None
)
gradientclip_helper
.
sync_global_norm
(
gradientclip_helper
.
sync_global_norm
(
main_block
,
[
self
.
mp_ring_id
,
self
.
pp_ring_id
])
main_block
,
[
self
.
mp_ring_id
,
self
.
pp_ring_id
]
,
self
.
mp_rank
)
def
_insert_loss_grad_scale_op
(
self
):
def
_insert_loss_grad_scale_op
(
self
):
main_block
=
self
.
_main_program
.
global_block
()
main_block
=
self
.
_main_program
.
global_block
()
...
...
python/paddle/fluid/optimizer.py
浏览文件 @
a4eadd15
...
@@ -4381,7 +4381,7 @@ class PipelineOptimizer(object):
...
@@ -4381,7 +4381,7 @@ class PipelineOptimizer(object):
persistable
=
source_var
.
persistable
)
persistable
=
source_var
.
persistable
)
else
:
else
:
dest_var
=
block
.
_clone_variable
(
source_var
,
False
)
dest_var
=
block
.
_clone_variable
(
source_var
,
False
)
dest_var
.
stop_gradient
=
source_var
.
stop_gradient
self
.
_clone_var_attr
(
dest_var
,
source_var
)
# When use with sharding, allreduce_sum and allreduce_max
# When use with sharding, allreduce_sum and allreduce_max
# used for global gradient clip and amp will be added by sharding.
# used for global gradient clip and amp will be added by sharding.
op_idx
+=
1
op_idx
+=
1
...
@@ -4547,9 +4547,14 @@ class PipelineOptimizer(object):
...
@@ -4547,9 +4547,14 @@ class PipelineOptimizer(object):
persistable
=
ref_var
.
persistable
,
persistable
=
ref_var
.
persistable
,
is_data
=
ref_var
.
is_data
,
is_data
=
ref_var
.
is_data
,
need_check_feed
=
ref_var
.
desc
.
need_check_feed
())
need_check_feed
=
ref_var
.
desc
.
need_check_feed
())
new_var
.
stop_gradient
=
ref_var
.
stop_gradient
self
.
_clone_var_attr
(
new_var
,
ref_var
)
return
new_var
return
new_var
def
_clone_var_attr
(
self
,
dest
,
src
):
dest
.
stop_gradient
=
src
.
stop_gradient
if
hasattr
(
src
,
'is_distributed'
):
dest
.
is_distributed
=
src
.
is_distributed
def
_strip_grad_suffix
(
self
,
name
):
def
_strip_grad_suffix
(
self
,
name
):
"""
"""
Strip the grad suffix from the given variable name
Strip the grad suffix from the given variable name
...
@@ -5209,6 +5214,8 @@ class PipelineOptimizer(object):
...
@@ -5209,6 +5214,8 @@ class PipelineOptimizer(object):
persistable
=
True
,
persistable
=
True
,
stop_gradient
=
False
)
stop_gradient
=
False
)
real_param
=
main_block
.
var
(
param
)
real_param
=
main_block
.
var
(
param
)
if
hasattr
(
real_param
,
'is_distributed'
):
merged_grad_var
.
is_distributed
=
real_param
.
is_distributed
tmp_size
=
self
.
_get_var_size
(
real_grad
)
tmp_size
=
self
.
_get_var_size
(
real_grad
)
# two strategies for splitting the grad
# two strategies for splitting the grad
# 1. the current segment's size reach the user defined grad_size_in_MB
# 1. the current segment's size reach the user defined grad_size_in_MB
...
...
python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
浏览文件 @
a4eadd15
...
@@ -658,6 +658,33 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
...
@@ -658,6 +658,33 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
'c_gen_nccl_id'
,
'c_comm_init'
'c_gen_nccl_id'
,
'c_comm_init'
])
])
self
.
assertEqual
(
main_prog_op_types
,
[
'partial_recv'
,
'partial_allgather'
,
'cast'
,
'cast'
,
'mul'
,
'cast'
,
'elementwise_add'
,
'cast'
,
'tanh'
,
'cast'
,
'cast'
,
'mul'
,
'cast'
,
'elementwise_add'
,
'cast'
,
'tanh'
,
'cast'
,
'cast'
,
'mul'
,
'cast'
,
'elementwise_add'
,
'cast'
,
'tanh'
,
'cast'
,
'cast'
,
'mul'
,
'cast'
,
'elementwise_add'
,
'softmax'
,
'cast'
,
'cross_entropy2'
,
'mean'
,
'elementwise_mul'
,
'fill_constant'
,
'elementwise_mul_grad'
,
'mean_grad'
,
'cross_entropy_grad2'
,
'cast'
,
'softmax_grad'
,
'elementwise_add_grad'
,
'mul_grad'
,
'cast'
,
'tanh_grad'
,
'cast'
,
'elementwise_add_grad'
,
'mul_grad'
,
'cast'
,
'tanh_grad'
,
'cast'
,
'elementwise_add_grad'
,
'mul_grad'
,
'cast'
,
'tanh_grad'
,
'cast'
,
'elementwise_add_grad'
,
'mul_grad'
,
'cast'
,
'c_sync_calc_stream'
,
'partial_send'
,
'fill_constant'
,
'cast'
,
'sum'
,
'fill_constant'
,
'cast'
,
'sum'
,
'fill_constant'
,
'cast'
,
'sum'
,
'fill_constant'
,
'cast'
,
'sum'
,
'fill_constant'
,
'cast'
,
'sum'
,
'fill_constant'
,
'cast'
,
'sum'
,
'fill_constant'
,
'cast'
,
'sum'
,
'fill_constant'
,
'cast'
,
'sum'
,
'c_sync_comm_stream'
,
'check_finite_and_unscale'
,
'cast'
,
'c_allreduce_max'
,
'c_allreduce_max'
,
'cast'
,
'update_loss_scaling'
,
'fill_constant'
,
'c_allreduce_sum'
,
'c_allreduce_sum'
,
'sqrt'
,
'fill_constant'
,
'elementwise_max'
,
'elementwise_div'
,
'elementwise_mul'
,
'elementwise_mul'
,
'elementwise_mul'
,
'elementwise_mul'
,
'elementwise_mul'
,
'elementwise_mul'
,
'elementwise_mul'
,
'elementwise_mul'
,
'momentum'
,
'momentum'
,
'momentum'
,
'momentum'
,
'momentum'
,
'momentum'
,
'momentum'
,
'momentum'
])
# pp + mp, partial send recv
# pp + mp, partial send recv
self
.
assertIn
(
'partial_recv'
,
main_prog_op_types
)
self
.
assertIn
(
'partial_recv'
,
main_prog_op_types
)
self
.
assertIn
(
'partial_allgather'
,
main_prog_op_types
)
self
.
assertIn
(
'partial_allgather'
,
main_prog_op_types
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录