PaddlePaddle / Paddle · commit bec9fc9a (unverified)
Authored Sep 29, 2021 by WangXi; committed via GitHub on Sep 29, 2021.

[hybrid] Fix model parallel non-distributed param broadcast (#36186)
Parent: f703558d
Showing 4 changed files with 105 additions and 69 deletions (+105 -69):

python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py (+31 -17)
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py (+59 -37)
python/paddle/fluid/tests/unittests/test_fleet_hybrid_meta_optimizer.py (+7 -9)
python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py (+8 -6)
python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py

@@ -25,8 +25,9 @@ class OffloadHelper(object):
     cuda_place_type = 1
     cuda_pinned_place_type = 2
 
-    def __init__(self, ring_id=None):
-        self.ring_id = ring_id
+    def __init__(self, mp_ring_id=None, dp_ring_id=None):
+        self.mp_ring_id = mp_ring_id
+        self.dp_ring_id = dp_ring_id
 
     def _insert_cast_op(self, block, idx, src_name, dst_name):
         src_var = block.var(src_name)
@@ -49,20 +50,31 @@ class OffloadHelper(object):
                 OP_ROLE_KEY: OpRole.Optimize
             })
 
-    def _insert_broadcast_op(self, block, idx, param):
-        if self.ring_id is None:
-            return
-        block._insert_op_without_sync(
-            idx,
-            type="c_broadcast",
-            inputs={'X': param},
-            outputs={'Out': param},
-            attrs={
-                'ring_id': self.ring_id,
-                'root': 0,
-                'use_calc_stream': True,
-                OP_ROLE_KEY: OpRole.Forward,
-            })
+    def _insert_broadcast_op(self, block, idx, param_name):
+        rings = []
+
+        if self.dp_ring_id is not None:
+            rings.append(self.dp_ring_id)
+
+        # need sync non distributed param in mp group
+        if self.mp_ring_id is not None:
+            param = block.var(param_name)
+            if not hasattr(param, 'is_distributed') or not param.is_distributed:
+                rings.append(self.mp_ring_id)
+
+        # the insert op order is: mp, dp
+        for ring in rings:
+            block._insert_op_without_sync(
+                idx,
+                type="c_broadcast",
+                inputs={'X': param_name},
+                outputs={'Out': param_name},
+                attrs={
+                    'ring_id': ring,
+                    'root': 0,
+                    'use_calc_stream': True,
+                    OP_ROLE_KEY: OpRole.Forward,
+                })
 
     def _insert_memcpy_op(self, block, idx, src_name, dst_name,
                           dst_place_type):
         src_var = block.var(src_name)
@@ -236,7 +248,7 @@ class OffloadHelper(object):
             self._insert_cast_op(startup_block, insert_idx, var_name,
                                  param_to_fp16[var_name])
             # NOTE(wangxi): cast and offload should insert after broadcast param.
-            # the insert op order is: broadcast, cast, offload
+            # the insert op order is: {mp, dp}broadcast, cast, offload
             self._insert_broadcast_op(startup_block, insert_idx, var_name)
@@ -489,6 +501,8 @@ class OffloadHelper(object):
             self._insert_cast_op(startup_block, insert_idx, var_name,
                                  param_to_fp16[var_name])
+            # NOTE(wangxi): cast and offload should insert after broadcast param.
+            # the insert op order is: {mp, dp}broadcast, cast, offload
             self._insert_broadcast_op(startup_block, insert_idx, var_name)
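Read as a whole, the offload_helper.py change replaces the old single-ring broadcast with a per-parameter ring selection: every parameter is still broadcast in the data-parallel ring when one is configured, and a parameter that is not distributed across model-parallel ranks is now additionally broadcast in the model-parallel ring, with the mp broadcast ending up before the dp one. A minimal standalone sketch of that selection rule (plain Python with a hypothetical Param stand-in instead of a Paddle block/var; illustrative only, not part of the patch):

    from dataclasses import dataclass
    from typing import List, Optional


    @dataclass
    class Param:
        # hypothetical stand-in for the attributes the patch inspects
        name: str
        is_distributed: bool = False


    def broadcast_rings(param: Param,
                        mp_ring_id: Optional[int],
                        dp_ring_id: Optional[int]) -> List[int]:
        """Rings a startup c_broadcast of `param` should use, in the
        resulting op order (mp first, then dp), mirroring the rule above."""
        rings = []
        # a non-distributed param must also be synced inside the mp group
        if mp_ring_id is not None and not param.is_distributed:
            rings.append(mp_ring_id)
        # every param is synced across the dp group when dp is enabled
        if dp_ring_id is not None:
            rings.append(dp_ring_id)
        return rings


    # an mp-sharded weight only needs the dp broadcast ...
    assert broadcast_rings(Param("fc_w", is_distributed=True), 0, 1) == [1]
    # ... while a replicated one (e.g. a LayerNorm scale) needs mp and dp
    assert broadcast_rings(Param("ln_scale"), 0, 1) == [0, 1]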
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py

@@ -467,14 +467,16 @@ class ShardingOptimizer(MetaOptimizerBase):
         main_block = self._main_program.global_block()
         startup_block = self._startup_program.global_block()
 
+        mp_ring_id = self.mp_ring_id if self.mp_degree > 1 else None
         dp_ring_id = self.dp_ring_id if self.dp_degree > 1 else None
+        offload_helper = OffloadHelper(
+            mp_ring_id=mp_ring_id, dp_ring_id=dp_ring_id)
 
         # optimize offload should be enable while gradient merge is enable and
         # acc_step is quite large (e.g. >> 100). Since its memcpy could not be
         # overlap with calc, otherwise it will slower down training severely.
         if sharding_configs["optimize_offload"]:
             logger.info("Sharding with optimize offload !")
-            offload_helper = OffloadHelper(ring_id=dp_ring_id)
             offload_helper.offload(main_block, startup_block)
             # The optimize_cast is already included in offload_fp32param
             offload_helper.offload_fp32param(main_block, startup_block)
@@ -482,7 +484,6 @@ class ShardingOptimizer(MetaOptimizerBase):
             logger.info("Sharding with optimize cast !")
             # NOTE(wangxi): optimize_cast will persist fp16 param, it
             # will take more memory, but will be faster. Trade space for time.
-            offload_helper = OffloadHelper(ring_id=dp_ring_id)
             if self._optimizer_sharding:
                 offload_helper.opt_sharding_cast_fp32param(
                     main_block, startup_block,
@@ -554,6 +555,10 @@ class ShardingOptimizer(MetaOptimizerBase):
         # init param broadcast should be called after startup pruning
         self._initialization_broadcast()
 
+        # NOTE(wangxi): if param is not persistable, program.clone will
+        # failed, so we remove no persistable param, recreate param as a var
+        self._recreate_not_persist_param_as_var()
+
         self._dump_program_for_debug()
 
         # GPU need to wait server ready, GPU and NPU is Layered connection
@@ -1385,23 +1390,14 @@ class ShardingOptimizer(MetaOptimizerBase):
         return
 
-    def _initialization_broadcast(self):
-        """
-        this funtion is to ensure the initialization between dp group to be
-        identical when hybrid-dp is used.
-        """
-        if not self.hybrid_dp:
-            return
-
-        startup_block = self._startup_program.global_block()
-        params = startup_block.all_parameters()
-        params_name = []
-        # NOTE(wangxi): if param is not persistable, program.clone will
-        # failed, so we remove no persistable param, re add param as a var
-        for param in params:
-            params_name.append(param.name)
-            if not param.persistable:
+    def _recreate_not_persist_param_as_var(self):
+        def recreate_not_persist_param_as_var(program):
+            block = program.global_block()
+            params = block.all_parameters()
+            for param in params:
+                if param.persistable:
+                    continue
+
                 name = param.name
                 shape = param.shape
                 dtype = param.dtype
@@ -1411,15 +1407,14 @@ class ShardingOptimizer(MetaOptimizerBase):
                 trainable = param.trainable
                 optimize_attr = param.optimize_attr
                 regularizer = param.regularizer
                 have_dist_attr = False
                 is_distributed = False
                 if hasattr(param, 'is_distributed'):
                     have_dist_attr = True
                     is_distributed = param.is_distributed
 
-                startup_block._remove_var(name, sync=False)
-                var = startup_block.create_var(
+                block._remove_var(name, sync=False)
+                var = block.create_var(
                     name=name,
                     shape=shape,
                     dtype=dtype,
@@ -1431,6 +1426,31 @@ class ShardingOptimizer(MetaOptimizerBase):
                 if have_dist_attr:
                     var.is_distributed = is_distributed
 
+            block._sync_with_cpp()
+
+        recreate_not_persist_param_as_var(self._startup_program)
+        recreate_not_persist_param_as_var(self._main_program)
+
+    def _initialization_broadcast(self):
+        """
+        this funtion is to ensure the initialization between dp group to be
+        identical when hybrid-dp is used, and the initialization of
+        not distributed param between mp group to be identical.
+        """
+        if self.dp_degree <= 1 and self.mp_degree <= 1:
+            return
+
+        startup_block = self._startup_program.global_block()
+        params = startup_block.all_parameters()
+        params_name = []
+        not_dist_param_name = set()
+
+        for param in params:
+            params_name.append(param.name)
+            if not hasattr(param, 'is_distributed') or not param.is_distributed:
+                not_dist_param_name.add(param.name)
+
         # offload and optimize_cast will insert broadcast op
         broadcast_params = set()
         for op in startup_block.ops:
@@ -1439,23 +1459,25 @@ class ShardingOptimizer(MetaOptimizerBase):
         for param in params_name:
             if param in broadcast_params:
                 continue
-            startup_block.append_op(
-                type='c_broadcast',
-                inputs={'X': param},
-                outputs={'Out': param},
-                attrs={
-                    'ring_id': self.dp_ring_id,
-                    'root': 0,
-                    'use_calc_stream': True,
-                    OP_ROLE_KEY: OpRole.Forward
-                })
-        startup_block.append_op(
-            type='c_sync_comm_stream',
-            inputs={'X': params_name},
-            outputs={'Out': params_name},
-            attrs={'ring_id': self.dp_ring_id,
-                   OP_ROLE_KEY: OpRole.Forward})
+
+            rings = []
+            # need sync not distributed param in mp group
+            if self.mp_degree > 1 and param in not_dist_param_name:
+                rings.append(self.mp_ring_id)
+            if self.dp_degree > 1:
+                rings.append(self.dp_ring_id)
+
+            for ring in rings:
+                startup_block.append_op(
+                    type='c_broadcast',
+                    inputs={'X': param},
+                    outputs={'Out': param},
+                    attrs={
+                        'ring_id': ring,
+                        'root': 0,
+                        'use_calc_stream': True,
+                        OP_ROLE_KEY: OpRole.Forward
+                    })
 
         startup_block._sync_with_cpp()
python/paddle/fluid/tests/unittests/test_fleet_hybrid_meta_optimizer.py

@@ -72,8 +72,7 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
             'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
             'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
             'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
-            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
-            'c_sync_comm_stream'
+            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast'
         ])
 
         self.assertEqual(main_prog_op_types, [
@@ -155,8 +154,7 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
             'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
             'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
             'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
-            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
-            'c_sync_comm_stream'
+            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast'
         ])
 
         self.assertEqual(main_prog_op_types, [
@@ -218,7 +216,7 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
             'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id',
             'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_broadcast',
             'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
-            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
+            'c_broadcast', 'c_broadcast', 'c_broadcast'
         ])
 
         self.assertEqual(main_prog_op_types, [
@@ -292,7 +290,7 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
             'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id',
             'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_broadcast',
             'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
-            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
+            'c_broadcast', 'c_broadcast', 'c_broadcast'
         ])
 
         self.assertEqual(main_prog_op_types, [
@@ -371,7 +369,7 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
             'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_broadcast',
             'cast', 'c_broadcast', 'cast', 'c_broadcast',
             'cast', 'c_broadcast', 'cast', 'c_broadcast',
             'cast', 'c_broadcast', 'cast', 'c_broadcast',
-            'cast', 'c_broadcast', 'c_sync_comm_stream'
+            'cast', 'c_broadcast'
         ])
 
         self.assertEqual(main_prog_op_types, [
@@ -460,7 +458,7 @@ class TestFleetHybridOptimizerBoundary(TestFleetMetaOptimizer):
             'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
             'fill_constant', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id',
             'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id',
-            'c_comm_init', 'c_broadcast', 'c_sync_comm_stream'
+            'c_comm_init', 'c_broadcast'
         ])
 
         self.assertEqual(main_prog_op_types, [
@@ -511,7 +509,7 @@ class TestFleetHybridOptimizerBoundary(TestFleetMetaOptimizer):
             'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
             'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
             'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
-            'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_sync_comm_stream'
+            'c_gen_nccl_id', 'c_comm_init', 'c_broadcast'
         ])
 
         self.assertEqual(main_prog_op_types, [
python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py

@@ -655,7 +655,9 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
             'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
             'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
             'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
-            'c_gen_nccl_id', 'c_comm_init'
+            'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_broadcast',
+            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
+            'c_broadcast', 'c_broadcast'
         ])
 
         self.assertEqual(main_prog_op_types, [
@@ -764,7 +766,7 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
             'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
             'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_broadcast',
             'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
-            'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
+            'c_broadcast', 'c_broadcast'
         ])
 
         self.assertEqual(main_prog_op_types, [
@@ -932,7 +934,7 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
             'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'cast',
             'c_broadcast', 'cast', 'c_broadcast', 'cast',
             'c_broadcast', 'cast', 'c_broadcast', 'cast',
             'c_broadcast', 'cast', 'c_broadcast', 'cast',
-            'c_broadcast', 'c_sync_comm_stream'
+            'c_broadcast'
         ])
 
         self.assertEqual(main_prog_op_types, [
@@ -1029,7 +1031,7 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
             'c_broadcast', 'cast', 'memcpy', 'c_broadcast', 'cast', 'memcpy',
             'c_broadcast', 'cast', 'memcpy', 'c_broadcast', 'cast', 'memcpy',
             'c_broadcast', 'cast', 'memcpy', 'c_broadcast', 'cast', 'memcpy',
-            'c_broadcast', 'c_sync_comm_stream'
+            'c_broadcast'
         ])
 
         self.assertEqual(main_prog_op_types, [
@@ -1129,7 +1131,7 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
             'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'cast',
             'c_broadcast', 'cast', 'c_broadcast', 'cast',
             'c_broadcast', 'cast', 'c_broadcast', 'cast',
             'c_broadcast', 'cast', 'c_broadcast', 'cast',
-            'c_broadcast', 'c_sync_comm_stream'
+            'c_broadcast'
         ])
 
         self.assertEqual(main_prog_op_types, [
@@ -1221,7 +1223,7 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
             'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
             'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_broadcast',
             'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
-            'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
+            'c_broadcast', 'c_broadcast'
         ])
 
         self.assertEqual(main_prog_op_types, [
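Taken together, the test updates above all follow one pattern: the expected startup programs drop the trailing 'c_sync_comm_stream' (the new broadcasts are issued with use_calc_stream=True, so the separate comm-stream sync the old code appended appears to be unnecessary), and where model parallelism is enabled each non-distributed parameter picks up an extra 'c_broadcast' on the mp ring. Setting aside parameters that the offload/optimize_cast passes have already broadcast (those are skipped via broadcast_params), the broadcast count expected from _initialization_broadcast can be modeled with a small helper like the following; this is an illustrative sketch, not anything in Paddle:

    from typing import Iterable, Tuple


    def expected_init_broadcasts(params: Iterable[Tuple[str, bool]],
                                 mp_degree: int,
                                 dp_degree: int) -> int:
        """Count the c_broadcast ops _initialization_broadcast should append:
        one dp broadcast per param when dp_degree > 1, plus one mp broadcast
        per non-distributed param when mp_degree > 1."""
        count = 0
        for _name, is_distributed in params:
            if mp_degree > 1 and not is_distributed:
                count += 1
            if dp_degree > 1:
                count += 1
        return count


    # e.g. four replicated params and two mp-sharded ones under mp=2, dp=2:
    # replicated params get an mp and a dp broadcast, sharded ones only dp.
    params = [("w%d" % i, False) for i in range(4)] + [("emb0", True), ("emb1", True)]
    assert expected_init_broadcasts(params, mp_degree=2, dp_degree=2) == 4 * 2 + 2 * 1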