PaddlePaddle / Paddle, commit eef0a943 (unverified)
Authored Sep 28, 2021 by WangXi; committed by GitHub on Sep 28, 2021
[hybrid] optimizer sharding support optimize cast (#35878)
Parent: d5268a6e
Showing 5 changed files with 440 additions and 54 deletions (+440, -54)
python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py  (+205, -8)
python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py  (+67, -1)
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py  (+69, -18)
python/paddle/fluid/tests/unittests/test_fleet_hybrid_meta_optimizer.py  (+76, -0)
python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py  (+23, -27)
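Before the per-file diffs, a toy sketch of the rewrite that optimize_cast performs under optimizer sharding. It is not part of the commit; the op sequences mirror the docstring of the new opt_sharding_cast_fp32param helper below, and the names p, p_fp16 and p_fp16_recompute are illustrative only.

# Illustrative only: the per-parameter op pattern before and after the pass.
before = [
    "p_fp16 = cast(p)",            # fp32 -> fp16 cast used by the forward pass
    "p_fp16_recompute = cast(p)",  # a second cast emitted for the recompute subprogram
    "p = adam(p, g)",              # optimizer update stays in fp32
]
after = [
    "rename(p_fp16_recompute -> p_fp16)",  # reuse a single persistable fp16 buffer
    "p = adam(p, g)",                      # only the owning rank updates p
    "p_fp16 = cast(p)",                    # one cast, inserted right after the update
    "broadcast(p_fp16)",                   # fp16 param is broadcast to the other ranks
]

if __name__ == "__main__":
    print("before:", *before, sep="\n  ")
    print("after: ", *after, sep="\n  ")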
python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py
@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import copy
 from ..common import is_optimizer_op, OP_ROLE_KEY, OpRole, is_update_op
 from paddle.fluid import core, unique_name
 from .shard import Shard

 __all__ = []
@@ -23,11 +25,8 @@ class OffloadHelper(object):
     cuda_place_type = 1
     cuda_pinned_place_type = 2

-    def __init__(self):
-        pass
     "0: dst is on CPUPlace. "
     "1: dst is on CUDAPlace. "
     "2: dst is on CUDAPinnedPlace. "

+    def __init__(self, ring_id=None):
+        self.ring_id = ring_id

     def _insert_cast_op(self, block, idx, src_name, dst_name):
         src_var = block.var(src_name)
@@ -50,6 +49,21 @@ class OffloadHelper(object):
             OP_ROLE_KEY: OpRole.Optimize
         })

+    def _insert_broadcast_op(self, block, idx, param):
+        if self.ring_id is None:
+            return
+        block._insert_op_without_sync(
+            idx,
+            type="c_broadcast",
+            inputs={'X': param},
+            outputs={'Out': param},
+            attrs={
+                'ring_id': self.ring_id,
+                'root': 0,
+                'use_calc_stream': True,
+                OP_ROLE_KEY: OpRole.Forward,
+            })
+
     def _insert_memcpy_op(self, block, idx, src_name, dst_name,
                           dst_place_type):
         src_var = block.var(src_name)
         dst_var = block.var(dst_name)
@@ -206,6 +220,8 @@ class OffloadHelper(object):
         # step5: startup_block add offload
         visited_vars = set()
+        # FIXME(wangxi): should insert in idx, need move comm init to the head.
+        insert_idx = len(startup_block.ops)
         for idx, op in reversed(list(enumerate(startup_block.ops))):
             for out_name in op.output_arg_names:
                 if out_name in visited_vars:
@@ -213,13 +229,16 @@ class OffloadHelper(object):
                 if out_name in param_name_to_offload_name:
                     var_name = out_name
-                    # FIXME(wangxi): offload should insert after broadcast param
                     if offload:
                         offload_var_name = param_name_to_offload_name[var_name]
-                        self._insert_offload_op(startup_block, idx + 1,
+                        self._insert_offload_op(startup_block, insert_idx,
                                                 var_name, offload_var_name)
-                    self._insert_cast_op(startup_block, idx + 1, var_name,
+                    self._insert_cast_op(startup_block, insert_idx, var_name,
                                          param_to_fp16[var_name])
+                    # NOTE(wangxi): cast and offload should insert after broadcast param.
+                    # the insert op order is: broadcast, cast, offload
+                    self._insert_broadcast_op(startup_block, insert_idx,
+                                              var_name)
                 visited_vars.add(out_name)
@@ -303,3 +322,181 @@ class OffloadHelper(object):

         block._sync_with_cpp()
         startup_block._sync_with_cpp()
+
+    def opt_sharding_cast_fp32param(self,
+                                    block,
+                                    startup_block,
+                                    params,
+                                    offload=False):
+        """
+        (p_fp16) = cast(p)
+        (p_fp16_recompute) = cast(p)
+        (pout,) = adam(p)
+        ===========================>
+        rename(p_fp16_recompute, p_fp16)
+        (pout,) = adam(p)
+        (p_fp16) = cast(p)
+        broadcast(p_fp16)
+        """
+        global_params = set()
+        local_params = set()
+        param_to_fp16 = dict()
+        # recompute_var which need rename to fp16_param
+        fp16_param_to_recompute = dict()
+        recompute_to_fp16 = dict()
+
+        def remove_param(input_name):
+            global_params.remove(input_name)
+            if input_name in local_params:
+                local_params.remove(input_name)
+            if input_name in param_to_fp16:
+                fp16_param = param_to_fp16.pop(input_name)
+                if fp16_param in fp16_param_to_recompute:
+                    recompute = fp16_param_to_recompute.pop(fp16_param)
+                    recompute_to_fp16.pop(recompute)
+
+        # step1: record param
+        global_params = set(params)
+        for idx, op in reversed(list(enumerate(block.ops))):
+            if is_update_op(op):
+                param = op.desc.input("Param")[0]
+                local_params.add(param)
+
+        # step2: remove param which can't offload and
+        #        record param->fp16param, fp16param->recompute_var
+        for idx, op in enumerate(block.ops):
+            if is_optimizer_op(op):
+                break
+            # TODO (Yuang Liu): tmp solution for fuse_grad_merge + optimize_cast
+            if op.type == 'coalesce_tensor':
+                continue
+            for input_name in op.desc.input_arg_names():
+                if input_name not in global_params:
+                    continue
+
+                # param which will be used by fp32 op
+                if op.type != 'cast':
+                    remove_param(input_name)
+                    continue
+
+                # param is only used by cast op,
+                # which to cast fp32_param to fp16_param
+                output_name = op.output_arg_names[0]
+                if 'cast_fp16' not in output_name:
+                    remove_param(input_name)
+                    continue
+
+                if 'subprog' not in output_name:
+                    assert output_name == input_name + '.cast_fp16'
+                    assert input_name not in param_to_fp16, \
+                        "There must be only one cast op from fp32 param to fp16 param."
+                    param_to_fp16[input_name] = output_name
+                else:
+                    # fp16-->recompute_var
+                    assert input_name in param_to_fp16, \
+                        "param must first be cast to fp16"
+                    fp16_param = param_to_fp16[input_name]
+                    fp16_param_to_recompute[fp16_param] = output_name
+                    recompute_to_fp16[output_name] = fp16_param
+
+        param_name_to_offload_name = dict()
+        # step3: main_block add offload, cast op
+        # change recompute to fp16, remove cast(param) to fp16
+        for idx, op in reversed(list(enumerate(block.ops))):
+            if is_update_op(op):
+                param = op.desc.input("Param")[0]
+                if param not in global_params:
+                    continue
+                # step3.1: create offload_var
+                offload_var_name = self._get_offload_var_name(param)
+                param_name_to_offload_name[param] = offload_var_name
+                if offload:
+                    self._create_offload_var(param, offload_var_name,
+                                             [block, startup_block])
+
+                    # step3.2: insert cast op and offload op
+                    self._insert_offload_op(block, idx + 1, param,
+                                            offload_var_name)
+
+                assert param in param_to_fp16
+                fp16_param_name = param_to_fp16[param]
+                fp16_param_var = block.var(fp16_param_name)
+                fp16_param_var.persistable = True
+                self._insert_cast_op(block, idx + 1, param,
+                                     param_to_fp16[param])
+
+                if offload:
+                    # step3.3: insert fetch op
+                    self._insert_fetch_op(block, idx, offload_var_name, param)
+                continue
+
+            # step3.4: remove cast op
+            if op.type == 'cast':
+                input_name = op.desc.input_arg_names()[0]
+                if input_name in global_params:
+                    block._remove_op(idx, sync=False)
+                    continue
+
+            # step3.5: change recompute_param to fp16_param
+            for input_name in op.desc.input_arg_names():
+                if input_name in recompute_to_fp16:
+                    op._rename_input(input_name,
+                                     recompute_to_fp16[input_name])
+            for output_name in op.desc.output_arg_names():
+                if output_name in recompute_to_fp16:
+                    op._rename_output(output_name,
+                                      recompute_to_fp16[output_name])
+
+        # step4: remove recompute_param
+        for name in recompute_to_fp16.keys():
+            block._remove_var(name, sync=False)
+
+        # step5: remove fp32 param which not need
+        for idx, op in enumerate(block.ops):
+            if op.type not in ['coalesce_tensor', 'c_broadcast']:
+                continue
+            for input_name in op.desc.input_arg_names():
+                if input_name in param_to_fp16:
+                    op._rename_input(input_name, param_to_fp16[input_name])
+            for output_name in op.desc.output_arg_names():
+                if output_name in param_to_fp16:
+                    op._rename_output(output_name,
+                                      param_to_fp16[output_name])
+
+        for param in global_params:
+            assert param in param_to_fp16
+            fp16_param_name = param_to_fp16[param]
+            fp16_param_var = block.var(fp16_param_name)
+            fp16_param_var.persistable = True
+
+            if param not in local_params:
+                block._remove_var(param, sync=False)
+
+        # step6: startup_block add offload
+        visited_vars = set()
+        insert_idx = len(startup_block.ops)
+        for idx, op in reversed(list(enumerate(startup_block.ops))):
+            for out_name in op.output_arg_names:
+                if out_name in visited_vars:
+                    continue
+
+                if out_name in param_to_fp16:
+                    var_name = out_name
+                    if offload:
+                        self._insert_offload_op(
+                            startup_block, idx + 1, var_name,
+                            param_name_to_offload_name[var_name])
+
+                    self._insert_cast_op(startup_block, insert_idx, var_name,
+                                         param_to_fp16[var_name])
+
+                    self._insert_broadcast_op(startup_block, insert_idx,
+                                              var_name)
+
+                    if var_name not in local_params:
+                        param = startup_block.var(out_name)
+                        param.persistable = False
+
+                visited_vars.add(out_name)
+
+        block._sync_with_cpp()
+        startup_block._sync_with_cpp()
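A minimal sketch of how the new helper is meant to be driven. This is not from the commit; the real call site is the _apply_optimize_offload_pass change in sharding_optimizer.py below, and main_block, startup_block, params_grads, shard, strategy and dp_ring_id are assumed to be supplied by that pass.

from paddle.distributed.fleet.meta_optimizers.sharding import utils
from paddle.distributed.fleet.meta_optimizers.sharding.offload_helper import OffloadHelper


def apply_opt_sharding_cast(main_block, startup_block, params_grads, shard,
                            strategy, dp_ring_id):
    # ring_id=None keeps the pre-existing behaviour: no broadcast is inserted.
    helper = OffloadHelper(ring_id=dp_ring_id)
    # Keep one persistable fp16 copy per fp32 param, cast once after the
    # update, and broadcast the fp16 copy from the rank that owns the update.
    helper.opt_sharding_cast_fp32param(main_block, startup_block,
                                       [p.name for p, _ in params_grads])
    # Coalesce the per-param c_broadcast ops emitted above into fused groups.
    utils.fuse_opt_broadcast_param_ops(main_block, dp_ring_id, shard,
                                       strategy=strategy)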
python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py
@@ -14,7 +14,7 @@
 import paddle
 from paddle.fluid import core, unique_name
 from functools import reduce
-from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op, is_backward_op
+from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op, is_backward_op, is_optimizer_op
 from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
 import re
@@ -366,6 +366,24 @@ def insert_allreduce_ops(block,

 class FuseHelper(object):
+    @staticmethod
+    def sort_vars_by_dtype(block, vars_name):
+        fp32_vars = []
+        fp16_vars = []
+        other_vars = []
+        for var in vars_name:
+            dtype = block.var(var).dtype
+            if dtype == paddle.float32:
+                fp32_vars.append(var)
+            elif dtype == paddle.float16:
+                fp16_vars.append(var)
+            else:
+                other_vars.append(var)
+        assert len(other_vars) == 0, "only support fp32/fp16 vars for fuse"
+
+        fp32_vars.extend(fp16_vars)
+        return fp32_vars
+
     @staticmethod
     def get_fused_groups(block, vars_name, fuse_size=32.):
         """ coalesce tensor, get fused group """
@@ -639,6 +657,54 @@ def insert_broadcast_param_ops(block,
     return param_in_this_device


+def fuse_opt_broadcast_param_ops(block,
+                                 ring_id,
+                                 shard,
+                                 op_role=OpRole.Optimize,
+                                 strategy=None):
+    """
+    fuse optimizer sharding broadcast param ops
+    """
+    if strategy is None or not strategy.fuse_all_reduce_ops:
+        return
+
+    fuse_size = strategy.fuse_grad_size_in_MB
+
+    nranks = shard.worker_num
+    device_to_vars = [[] for _ in range(nranks)]
+
+    for idx, op in reversed(list(enumerate(block.ops))):
+        if not is_optimizer_op(op) or op.type != 'c_broadcast':
+            break
+        var = op.input_arg_names[0]
+        root_id = op.attr('root')
+        device_to_vars[root_id].insert(0, var)
+        block._remove_op(idx, sync=False)
+
+    insert_idx = idx + 1
+    for root_id, vars_name in enumerate(device_to_vars):
+        vars_name = FuseHelper.sort_vars_by_dtype(block, vars_name)
+        groups = FuseHelper.get_fused_groups(block, vars_name, fuse_size)
+
+        fused_vars, insert_num = FuseHelper.insert_coalesce_tensor(
+            block, insert_idx, groups, op_role, prefix="Param")
+
+        for fused_var in fused_vars:
+            block._insert_op_without_sync(
+                insert_idx + insert_num,
+                type='c_broadcast',
+                inputs={'X': fused_var},
+                outputs={'Out': fused_var},
+                attrs={
+                    'ring_id': ring_id,
+                    'root': root_id,
+                    'use_calc_stream': True,
+                    OP_ROLE_KEY: op_role
+                })
+
+    block._sync_with_cpp()
+
+
 def get_grad_device(grad_name, shard):
     assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format(
         grad_name)
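A quick check of the ordering contract of FuseHelper.sort_vars_by_dtype added above, using a stand-in block object; the method only touches block.var(name).dtype, and the variable names here are made up.

import paddle
from paddle.distributed.fleet.meta_optimizers.sharding.utils import FuseHelper


class _FakeVar(object):
    def __init__(self, dtype):
        self.dtype = dtype


class _FakeBlock(object):
    """Stand-in exposing the one accessor sort_vars_by_dtype relies on."""

    def __init__(self, dtypes):
        self._vars = {name: _FakeVar(dt) for name, dt in dtypes.items()}

    def var(self, name):
        return self._vars[name]


block = _FakeBlock({
    'w0.cast_fp16': paddle.float16,
    'b0': paddle.float32,
    'w1.cast_fp16': paddle.float16,
})
# fp32 vars come first, then fp16 vars, so each coalesced group stays
# dtype-homogeneous: ['b0', 'w0.cast_fp16', 'w1.cast_fp16']
print(FuseHelper.sort_vars_by_dtype(block, ['w0.cast_fp16', 'b0', 'w1.cast_fp16']))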
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -329,6 +329,7 @@ class ShardingOptimizer(MetaOptimizerBase):
         if self.pp_degree == 1: return

         strategy = self.user_defined_strategy
+        sharding_configs = strategy.sharding_configs
         main_block = self._main_program.global_block()
         startup_block = self._startup_program.global_block()
@@ -399,6 +400,8 @@ class ShardingOptimizer(MetaOptimizerBase):
             first_optimize_op_index += (len(main_block.ops) - len_of_ops)
             len_of_ops = len(main_block.ops)

+            # NOTE(wangxi): we fused after optimize_cast
+            optimize_cast = sharding_configs['optimize_cast']
             optimizer_param = utils.insert_broadcast_param_ops(
                 main_block, len_of_ops,
@@ -407,10 +410,10 @@
                 OpRole.Optimize,
                 use_calc_stream=True,
                 rank=self.dp_rank,
-                strategy=strategy)
+                strategy=None if optimize_cast else strategy)
             logger.info("Optimizer param in this rank {}".format(
                 optimizer_param))
-            if not strategy.fuse_grad_merge:
+            if not strategy.fuse_grad_merge and not optimize_cast:
                 assert len(accumulated_grad_names) == len(optimizer_param)
         elif self.hybrid_dp and self.hybrid_dp_mode == "pp_hybrid_dp":
             insert_allreduce_ops(
@@ -458,18 +461,20 @@
         main_block._sync_with_cpp()

-    def _apply_optimize_offload_pass(self):
+    def _apply_optimize_offload_pass(self, params_grads):
         strategy = self.user_defined_strategy
         sharding_configs = strategy.sharding_configs
         main_block = self._main_program.global_block()
         startup_block = self._startup_program.global_block()

+        dp_ring_id = self.dp_ring_id if self.dp_degree > 1 else None
+
         # optimize offload should be enable while gradient merge is enable and
         # acc_step is quite large (e.g. >> 100). Since its memcpy could not be
         # overlap with calc, otherwise it will slower down training severely.
         if sharding_configs["optimize_offload"]:
             logger.info("Sharding with optimize offload !")
-            offload_helper = OffloadHelper()
+            offload_helper = OffloadHelper(ring_id=dp_ring_id)
             offload_helper.offload(main_block, startup_block)
             # The optimize_cast is already included in offload_fp32param
             offload_helper.offload_fp32param(main_block, startup_block)
@@ -477,8 +482,17 @@
             logger.info("Sharding with optimize cast !")
             # NOTE(wangxi): optimize_cast will persist fp16 param, it
             # will take more memory, but will be faster. Trade space for time.
-            offload_helper = OffloadHelper()
-            offload_helper.cast_fp32param_in_optimize(main_block, startup_block)
+            offload_helper = OffloadHelper(ring_id=dp_ring_id)
+            if self._optimizer_sharding:
+                offload_helper.opt_sharding_cast_fp32param(
+                    main_block, startup_block,
+                    [x[0].name for x in params_grads])
+                # NOTE(wangxi): fused after optimize_cast
+                utils.fuse_opt_broadcast_param_ops(
+                    main_block, dp_ring_id, self._shard, strategy=strategy)
+            else:
+                offload_helper.cast_fp32param_in_optimize(main_block,
+                                                          startup_block)

     def _dump_program_for_debug(self):
         main_block = self._main_program.global_block()
@@ -525,7 +539,7 @@
         self._insert_loss_grad_scale_op()

         # apply optimize offload or optimize cast
-        self._apply_optimize_offload_pass()
+        self._apply_optimize_offload_pass(params_grads)

         # step6: (optional) sharding gradient merge
         self._sharding_gradient_merge()
@@ -1381,17 +1395,50 @@
         startup_block = self._startup_program.global_block()
         params = startup_block.all_parameters()
         params_name = []
         broadcast_params = []
         # NOTE(wangxi): if param is not persistable, program.clone will
         # failed, so we remove no persistable param, re add param as a var
         for param in params:
             broadcast_params.append(param)
             # optimize_cast need broadcast fp16 param
             fp16_param_name = param.name + '.cast_fp16'
             if startup_block.has_var(fp16_param_name):
                 fp16_param = startup_block.var(fp16_param_name)
                 broadcast_params.append(fp16_param)

         for param in broadcast_params:
             params_name.append(param.name)
             if not param.persistable:
                 name = param.name
                 shape = param.shape
                 dtype = param.dtype
                 type = param.type
                 lod_level = param.lod_level
                 stop_gradient = param.stop_gradient
                 trainable = param.trainable
                 optimize_attr = param.optimize_attr
                 regularizer = param.regularizer

                 have_dist_attr = False
                 is_distributed = False
                 if hasattr(param, 'is_distributed'):
                     have_dist_attr = True
                     is_distributed = param.is_distributed

                 startup_block._remove_var(name, sync=False)
                 var = startup_block.create_var(
                     name=name,
                     shape=shape,
                     dtype=dtype,
                     type=type,
                     lod_level=lod_level,
                     stop_gradient=stop_gradient,
                     trainable=trainable,
                     persistable=False)
                 if have_dist_attr:
                     var.is_distributed = is_distributed

         # offload and optimize_cast will insert broadcast op
         broadcast_params = set()
         for op in startup_block.ops:
             if op.type == 'c_broadcast':
                 broadcast_params.add(op.desc.output_arg_names()[0])

         for param in params_name:
             if param in broadcast_params: continue

             startup_block.append_op(
                 type='c_broadcast',
                 inputs={'X': param},
@@ -1399,15 +1446,19 @@
                 attrs={
                     'ring_id': self.dp_ring_id,
                     'root': 0,
                     'use_calc_stream': True,
                     OP_ROLE_KEY: OpRole.Forward
                 })

         startup_block.append_op(
             type='c_sync_comm_stream',
-            inputs={'X': broadcast_params},
-            outputs={'Out': broadcast_params},
+            inputs={'X': params_name},
+            outputs={'Out': params_name},
             attrs={
                 'ring_id': self.dp_ring_id,
                 OP_ROLE_KEY: OpRole.Forward
             })

         startup_block._sync_with_cpp()

     # sharding gradient merge
     def create_persistable_gradients_and_insert_merge_ops(
             self, main_block, startup_block, insert_idx, grad_names, shard):
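The hybrid unit test below exercises this path end to end. In isolation, the strategy switches it relies on look like the following sketch (only the flags; the model, fleet initialization and program setup done in the test are omitted):

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "sharding_degree": 1,
    "pp_degree": 2,
    "dp_degree": 2,
    "_dp_as_optimizer_sharding": True,  # shard the optimizer state over the dp group
    "optimize_cast": True,              # cast fp32 -> fp16 once, right after the update
}
# The fp16 param broadcasts inserted by optimize_cast are then coalesced.
strategy.fuse_all_reduce_ops = True
strategy.fuse_grad_size_in_MB = 32
strategy.fuse_grad_merge = True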
python/paddle/fluid/tests/unittests/test_fleet_hybrid_meta_optimizer.py
@@ -321,6 +321,82 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
             'c_broadcast'
         ])

+    def test_opt_sharding_with_pp_amp_ckp_fuse_gm_optcast(self):
+        train_prog, startup_prog = static.Program(), static.Program()
+        avg_cost, strategy = self.pp_net(train_prog, startup_prog)
+
+        self.set_strategy(strategy, 'pipeline')
+        self.set_strategy(strategy, 'amp')
+        strategy.amp_configs = {'custom_black_varnames': ['fc_6.b_0'], }
+        strategy.recompute = True
+        strategy.recompute_configs = {
+            "checkpoints":
+            ["fc_0.tmp_2", "fc_1.tmp_2", "fc_2.tmp_2", "fc_3.tmp_2"]
+        }
+
+        strategy.sharding = True
+        strategy.sharding_configs = {
+            "sharding_degree": 1,
+            "pp_degree": 2,
+            "dp_degree": 2,
+            "_dp_as_optimizer_sharding": True,
+            'optimize_cast': True,
+        }
+        strategy.fuse_all_reduce_ops = True
+        strategy.fuse_grad_size_in_MB = 32
+        strategy.fuse_grad_merge = True
+
+        self.optimizer(avg_cost, strategy, train_prog, startup_prog)
+        train_prog = train_prog._pipeline_opt['section_program']
+        startup_prog = startup_prog._pipeline_opt['startup_program']
+
+        # self._debug = True
+        self.debug_program(train_prog, startup_prog)
+
+        startup_prog_ops = startup_prog.global_block().ops
+        main_prog_ops = train_prog.global_block().ops
+
+        # check program
+        startup_prog_op_types = [op.type for op in startup_prog_ops]
+        main_prog_op_types = [op.type for op in main_prog_ops]
+
+        # global, sharding, pp_send, pp_recv
+        self.assertEqual(startup_prog_op_types, [
+            'uniform_random', 'fill_constant', 'uniform_random',
+            'fill_constant', 'uniform_random', 'fill_constant',
+            'uniform_random', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
+            'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
+            'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast',
+            'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast',
+            'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast',
+            'c_sync_comm_stream'
+        ])
+
+        self.assertEqual(main_prog_op_types, [
+            'recv_v2', 'cast', 'mul', 'elementwise_add', 'cast', 'tanh',
+            'cast', 'mul', 'elementwise_add', 'cast', 'tanh', 'cast', 'mul',
+            'elementwise_add', 'cast', 'tanh', 'cast', 'mul', 'cast',
+            'elementwise_add', 'cast', 'softmax', 'cast', 'cross_entropy2',
+            'mean', 'elementwise_mul', 'coalesce_tensor', 'coalesce_tensor',
+            'coalesce_tensor', 'coalesce_tensor', 'coalesce_tensor',
+            'coalesce_tensor', 'fill_constant', 'elementwise_mul_grad',
+            'mean_grad', 'cross_entropy_grad2', 'cast', 'softmax_grad',
+            'cast', 'elementwise_add_grad', 'cast', 'mul_grad', 'cast',
+            'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad', 'cast',
+            'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad', 'cast',
+            'cast', 'mul', 'elementwise_add', 'cast', 'tanh_grad', 'cast',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'c_sync_calc_stream',
+            'send_v2', 'cast', 'sum', 'sum', 'cast', 'sum', 'c_reduce_sum',
+            'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream',
+            'check_finite_and_unscale', 'cast', 'c_allreduce_max',
+            'c_allreduce_max', 'cast', 'update_loss_scaling', 'momentum',
+            'cast', 'momentum', 'cast', 'momentum', 'cast', 'momentum',
+            'momentum', 'cast', 'coalesce_tensor', 'c_broadcast',
+            'c_broadcast', 'coalesce_tensor', 'c_broadcast'
+        ])
+

 class TestFleetHybridOptimizerBoundary(TestFleetMetaOptimizer):
     def setUp(self):
python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
@@ -922,18 +922,17 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
         # ring: mp, pp_group, pp_pair, pp_pair
         self.assertEqual(startup_prog_op_types, [
             'uniform_random', 'cast', 'fill_constant', 'cast',
             'uniform_random', 'cast', 'fill_constant', 'cast',
             'uniform_random', 'cast', 'fill_constant', 'cast',
             'uniform_random', 'cast', 'fill_constant', 'fill_constant',
             'fill_constant', 'fill_constant', 'fill_constant',
             'uniform_random', 'fill_constant', 'uniform_random',
             'fill_constant', 'uniform_random', 'fill_constant',
             'uniform_random', 'fill_constant', 'fill_constant',
             'fill_constant', 'fill_constant', 'fill_constant',
             'fill_constant', 'fill_constant', 'fill_constant',
             'fill_constant', 'fill_constant', 'fill_constant',
             'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
             'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
             'c_gen_nccl_id', 'c_comm_init',
-            'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_broadcast',
-            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
-            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
-            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
-            'c_broadcast', 'c_sync_comm_stream'
+            'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'cast',
+            'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast',
+            'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast',
+            'c_broadcast', 'cast', 'c_broadcast', 'c_sync_comm_stream'
         ])

         self.assertEqual(main_prog_op_types, [
@@ -1019,19 +1018,17 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
         # ring: mp, pp_group, pp_pair, pp_pair
         self.assertEqual(startup_prog_op_types, [
             'uniform_random', 'cast', 'memcpy', 'fill_constant', 'cast',
             'memcpy', 'uniform_random', 'cast', 'memcpy', 'fill_constant',
             'cast', 'memcpy', 'uniform_random', 'cast', 'memcpy',
             'fill_constant', 'cast', 'memcpy', 'uniform_random', 'cast',
             'memcpy', 'fill_constant', 'fill_constant', 'fill_constant',
             'uniform_random', 'fill_constant', 'uniform_random',
             'fill_constant', 'uniform_random', 'fill_constant',
             'uniform_random', 'fill_constant', 'fill_constant',
             'fill_constant', 'fill_constant', 'fill_constant',
             'fill_constant', 'fill_constant', 'fill_constant',
             'fill_constant', 'fill_constant', 'fill_constant',
             'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
             'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
             'c_gen_nccl_id', 'c_comm_init',
-            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
-            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
-            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
-            'c_broadcast', 'c_broadcast',
+            'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'cast', 'memcpy',
+            'c_broadcast', 'cast', 'memcpy', 'c_broadcast', 'cast', 'memcpy',
+            'c_broadcast', 'cast', 'memcpy', 'c_broadcast', 'cast', 'memcpy',
+            'c_broadcast', 'cast', 'memcpy', 'c_broadcast', 'cast', 'memcpy',
             'c_broadcast', 'c_sync_comm_stream'
         ])
@@ -1122,18 +1119,17 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
         # ring: mp, pp_group, pp_pair, pp_pair
         self.assertEqual(startup_prog_op_types, [
             'uniform_random', 'cast', 'fill_constant', 'cast',
             'uniform_random', 'cast', 'fill_constant', 'cast',
             'uniform_random', 'cast', 'fill_constant', 'cast',
             'uniform_random', 'cast', 'fill_constant', 'fill_constant',
             'fill_constant', 'fill_constant', 'fill_constant',
             'uniform_random', 'fill_constant', 'uniform_random',
             'fill_constant', 'uniform_random', 'fill_constant',
             'uniform_random', 'fill_constant', 'fill_constant',
             'fill_constant', 'fill_constant', 'fill_constant',
             'fill_constant', 'fill_constant', 'fill_constant',
             'fill_constant', 'fill_constant', 'fill_constant',
             'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
             'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
             'c_gen_nccl_id', 'c_comm_init',
-            'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_broadcast',
-            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
-            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
-            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
-            'c_broadcast', 'c_sync_comm_stream'
+            'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'cast',
+            'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast',
+            'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast',
+            'c_broadcast', 'cast', 'c_broadcast', 'c_sync_comm_stream'
         ])

         self.assertEqual(main_prog_op_types, [