PaddlePaddle / Paddle — commit a9f877ff (unverified)
Authored by Yuang Liu on Jul 24, 2023; committed via GitHub on Jul 24, 2023.
[sharding stage 1 optim] Sharding comm overlap with backward (#55598)
Parent: b10b899c

Showing 7 changed files with 307 additions and 158 deletions (+307, −158):
- paddle/fluid/framework/distributed_strategy.proto (+2, −0)
- python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py (+37, −13)
- python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py (+7, −5)
- python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py (+0, −124)
- python/paddle/distributed/fleet/utils/tensor_fusion_helper.py (+258, −15)
- test/collective/fleet/hybrid_parallel_sharding_model_with_fusion.py (+2, −0)
- test/legacy_test/test_fused_comm_buffer.py (+1, −1)
paddle/fluid/framework/distributed_strategy.proto @ a9f877ff

@@ -68,6 +68,8 @@ message PpConfig {

 message DygraphShardingConfig {
   optional bool tensor_fusion = 1 [ default = false ];
+  optional int32 accumulate_steps = 2 [ default = 1 ];
+  optional bool comm_overlap = 3 [ default = false ];
 }

 message HybridConfig {
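In Python these two new fields surface as sharding_configs options on fleet.DistributedStrategy, next to the existing tensor_fusion switch; the hybrid_parallel_sharding_model_with_fusion.py test at the end of this diff flips exactly these flags. A minimal configuration sketch (the parallel degrees are illustrative, and fleet.init expects to run under a distributed launcher such as paddle.distributed.launch):

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 1,
    "sharding_degree": 2,  # illustrative: shard optimizer state across 2 ranks
}
# tensor_fusion existed before this commit; comm_overlap and accumulate_steps are the new options.
strategy.hybrid_configs["sharding_configs"].tensor_fusion = True
strategy.hybrid_configs["sharding_configs"].comm_overlap = True
strategy.hybrid_configs["sharding_configs"].accumulate_steps = 1

fleet.init(is_collective=True, strategy=strategy)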
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py @ a9f877ff

@@ -78,12 +78,23 @@ class DygraphShardingOptimizer:
         self.tensor_fusion = strategy.hybrid_configs[
             'sharding_configs'
         ].tensor_fusion
+        self.accumulate_steps = strategy.hybrid_configs[
+            'sharding_configs'
+        ].accumulate_steps
+        self.comm_overlap = strategy.hybrid_configs[
+            'sharding_configs'
+        ].comm_overlap
         pp_overlap = strategy.hybrid_configs['pp_configs'].sharding_comm_overlap
-        if self.tensor_fusion:
+        if self.tensor_fusion or self.comm_overlap:
             assert (
                 not pp_overlap
             ), "Can not enable pp's sharding_comm_overlap and sharding's tensor_fusion at the same time."

+        self._use_main_grad = hasattr(self._parameter_list[0], "main_grad")
+        self._rank2decay = {}
+        self._rank2fused = {}
+        self._comm_buffers = []
+
         self._rank2params = self._partition_parameters()
         self._param2rank = self._map_param_to_rank()

@@ -95,25 +106,22 @@ class DygraphShardingOptimizer:
                 '_param_groups', self._rank2params[self._sharding_rank]
             )
         else:
-            self._use_main_grad = hasattr(self._parameter_list[0], "main_grad")
-            self._rank2decay = {}
-            self._rank2fused = {}
             self._tensor_fusion()

             decay_params = [
                 p.name for p in self._rank2decay[self._sharding_rank]
             ]
-            all_params = self._rank2fused[self._sharding_rank]
+            fused_params = self._rank2fused[self._sharding_rank]
             apply_decay_param_fun = lambda x: x in decay_params
-            params = []
+            all_fused_params = []
             for v in self._rank2fused.values():
-                params += v
-            self._parameter_list = params
-            self._param_groups = params
+                all_fused_params += v
+            self._parameter_list = all_fused_params
+            self._param_groups = all_fused_params

-            self._set_inner_opt_attr('_parameter_list', all_params)
-            self._set_inner_opt_attr('_param_groups', all_params)
+            self._set_inner_opt_attr('_parameter_list', fused_params)
+            self._set_inner_opt_attr('_param_groups', fused_params)
             origin_decay_param_fun = getattr(
                 self._inner_opt, '_apply_decay_param_fun', None
             )

@@ -145,11 +153,23 @@ class DygraphShardingOptimizer:
                 p.clear_gradient(set_to_zero)

     def _tensor_fusion(self):
+        comm_group = self._hcg.get_sharding_parallel_group()
         for i in range(self._sharding_world_size):
             params = self._rank2params[i]
-            decay_fused, all_fused = fused_parameters(
-                params, self._use_main_grad
+            dst = comm_group.ranks[i]
+            # TODO(sharding dev): make scale_after_comm a field to be configured by user
+            decay_fused, all_fused, all_buffer = fused_parameters(
+                params,
+                use_main_grad=self._use_main_grad,
+                fuse_param=True,
+                comm_overlap=self.comm_overlap,
+                comm_group=comm_group,
+                dst=dst,
+                acc_step=self.accumulate_steps,
+                scale_after_comm=False,
             )
+            if self.comm_overlap:
+                self._comm_buffers += all_buffer
             self._rank2decay[i] = decay_fused
             self._rank2fused[i] = all_fused
             for p in all_fused:

@@ -199,6 +219,10 @@ class DygraphShardingOptimizer:
     def reduce_gradients(self, parameter_list, hcg):
         # TODO merge grad / nrank with dp
         logger.debug("sharding start gradients sync")
+        if self.comm_overlap:
+            for buffer in self._comm_buffers:
+                buffer.scale_grads()
+            return
         with framework.no_grad():
             sharding_nrank = hcg.get_sharding_parallel_group().nranks
             for param in parameter_list:
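Taken together, the optimizer changes work as follows: every parameter gets a backward hook that checks its gradient into a FusedCommBuffer; once all gradients owned by a buffer have been checked in accumulate_steps times, the buffer scales its flattened grad storage by 1/nranks (scale_after_comm=False here) and launches an asynchronous reduce, overlapping the communication with the rest of backward; reduce_gradients then only calls scale_grads(), which waits on the outstanding tasks. A self-contained, pure-Python schematic of that check-in pattern (the class, names, and fake task below are illustrative, not Paddle API):

# Schematic of the check-in / async-reduce pattern used by FusedCommBuffer.
# `launch_async_reduce` stands in for paddle.distributed.reduce(..., sync_op=False).
class MiniCommBuffer:
    def __init__(self, param_names, acc_steps, launch_async_reduce):
        self._acc_steps = acc_steps
        self._steps = {name: 0 for name in param_names}
        self._checked_in = 0
        self._total = len(param_names)
        self._launch = launch_async_reduce
        self.task = None

    def add_grad(self, name):                # called from each parameter's backward hook
        self._steps[name] += 1
        if self._steps[name] == self._acc_steps:
            self._checked_in += 1
        if self._checked_in == self._total:  # last gradient arrived: start the collective now,
            self.task = self._launch()       # overlapping it with the remaining backward work

    def scale_grads(self):                   # called once from reduce_gradients()
        self.task.wait()                     # the real buffer also scales by 1/nranks before or after the comm


class FakeAsyncTask:
    def wait(self):
        print("async reduce finished")


buf = MiniCommBuffer(["w", "b"], acc_steps=1, launch_async_reduce=FakeAsyncTask)
buf.add_grad("w")
buf.add_grad("b")   # triggers the (fake) asynchronous reduce
buf.scale_grads()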
python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @ a9f877ff

@@ -37,11 +37,11 @@ else:
     from .pp_utils import p2p_communication as p2p

 from paddle.distributed.fleet.utils.tensor_fusion_helper import (
+    HOOK_ACTION,
+    FusedCommBuffer,
     assign_group_by_size,
 )
-from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer

 __all__ = []

 g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 0))

@@ -334,9 +334,11 @@ class PipelineParallel(MetaParallelBase):
         for dst in fused_parameter_group:
             parameter_list = fused_parameter_group[dst]
-            if not dp:
+            if act != HOOK_ACTION.ALL_REDUCE:
                 # parse the relative dst rank to absolute dst rank for sharding
                 dst = comm_group.ranks[dst]
+            else:
+                dst = -1
             var_groups = assign_group_by_size(parameter_list)
             for group_idx, parameters in var_groups.items():
                 buffer = FusedCommBuffer(

@@ -515,7 +517,7 @@
         if self._comm_overlap:
             assert len(self._comm_buffers) > 0
             for buffer in self._comm_buffers:
-                buffer.scale_and_split_grads()
+                buffer.scale_grads()

         if self._enable_timer:
             self.timers("allreduce_shared_weight_gradients").start()

@@ -1256,7 +1258,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
         if self._comm_overlap:
             assert len(self._comm_buffers) > 0
             for buffer in self._comm_buffers:
-                buffer.scale_and_split_grads()
+                buffer.scale_grads()

         if static_scheduler:
             self._reset_counter()
python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py @ a9f877ff

@@ -15,19 +15,10 @@
 import paddle
 from paddle import _legacy_C_ops
-from paddle.distributed.fleet.utils.tensor_fusion_helper import (
-    flatten_dense_tensors,
-)
-from paddle.framework import base as imperative_base

 __all__ = []


-class HOOK_ACTION:
-    ALL_REDUCE = 0
-    REDUCE = 1
-

 FLOAT_TYPE_DICT = {
     paddle.float16: "float16",
     paddle.float32: "float32",

@@ -116,118 +107,3 @@ def _all_gather(tensor, group=None, use_calc_stream=True):
         'nranks',
         nranks,
     )
-
-
-class FusedCommBuffer:
-    def __init__(self, id, params, comm_group, acc_steps=1, act=None, dst=-1):
-        self._id = id
-        self._params = params
-        self._acc_steps = acc_steps
-        self._comm_group = comm_group
-
-        self.use_main_grad = hasattr(self._params[0], "main_grad")
-
-        self._task = None
-        self._params_step_dict = {}
-        self._params_checked_in = 0
-        self._params_to_addr = {}
-
-        self._act = act
-        if self._act == HOOK_ACTION.ALL_REDUCE:
-            assert dst == -1
-        elif self._act == HOOK_ACTION.REDUCE:
-            assert dst != -1
-        else:
-            raise ValueError(
-                "The act should be allreudce for dp or reduce for sharding."
-            )
-        self._dst = dst
-
-        self._init_step_dict()
-
-        self.grad_storage = flatten_dense_tensors(
-            self._params,
-            use_main_grad=self.use_main_grad,
-            fuse_param=False,
-            warp_buffer=False,
-        ).buffer
-
-        self._record_addr()
-
-    def _record_addr(self):
-        for param in self._params:
-            addr = (
-                param.main_grad.data_ptr()
-                if self.use_main_grad
-                else param.grad.data_ptr()
-            )
-            self._params_to_addr[param.name] = addr
-
-    def _init_step_dict(self):
-        for p in self._params:
-            self._params_step_dict[p.name] = 0
-
-    def _reset_params_checked_in(self):
-        self._task = None
-        self._init_step_dict()
-        self._params_checked_in = 0
-
-    @property
-    def _all_params_checked_in(self):
-        return (
-            len(self._params) == self._params_checked_in
-            and len(self._params_step_dict) == 0
-        )
-
-    def add_grad(self, param):
-        assert param.name in self._params_step_dict
-        current_ptr = (
-            param.main_grad.data_ptr()
-            if self.use_main_grad
-            else param.grad.data_ptr()
-        )
-        if self._params_to_addr[param.name] != current_ptr:
-            raise ValueError(
-                "The address of the grad/main_grad of the param has been changed during training, "
-                "which is not allowed for dp/sharding overlap with pp. "
-                "This may be caused by some non-inplace operations on the grad/main_grad. "
-                "Please use the inplace version of the operations or disable the overlapping."
-            )
-
-        self._params_step_dict[param.name] += 1
-
-        if self._params_step_dict[param.name] == self._acc_steps:
-            self._params_checked_in += 1
-            self._params_step_dict.pop(param.name)
-
-        if self._all_params_checked_in:
-            self._comm_grads()
-
-    @imperative_base.no_grad
-    def _comm_grads(self):
-        assert self._all_params_checked_in
-
-        if self._act == HOOK_ACTION.ALL_REDUCE:
-            task = paddle.distributed.all_reduce(
-                self.grad_storage, group=self._comm_group, sync_op=False
-            )
-        elif self._act == HOOK_ACTION.REDUCE:
-            task = paddle.distributed.reduce(
-                self.grad_storage,
-                dst=self._dst,
-                group=self._comm_group,
-                sync_op=False,
-            )
-
-        self._task = task
-
-    @imperative_base.no_grad
-    def scale_and_split_grads(self):
-        assert self._task is not None
-        self._task.wait()
-
-        scale_factor = 1.0 / self._comm_group.nranks
-        self.grad_storage.scale_(scale_factor)
-
-        self._reset_params_checked_in()
python/paddle/distributed/fleet/utils/tensor_fusion_helper.py @ a9f877ff

@@ -12,13 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
+import os
 from collections import OrderedDict

 import numpy as np

 import paddle
+from paddle.framework import base as imperative_base
 from paddle.framework import core


+class HOOK_ACTION:
+    ALL_REDUCE = 0
+    REDUCE = 1
+
+
 alignment = {
     "gpu": 256,
 }

@@ -101,23 +109,204 @@ def flatten_dense_tensors(
     return grad_storage


-def obtain_storage(parameters, use_main_grad, clip, dist):
+def bw_hook_func(buffer, param):
+    @paddle.autograd.no_grad()
+    def fused_comm(*_):
+        buffer.add_grad(param)
+
+    return fused_comm
+
+
+class FusedCommBuffer:
+    def __init__(
+        self,
+        id,
+        params,
+        comm_group,
+        acc_steps=1,
+        act=None,
+        dst=-1,
+        use_main_grad=None,
+        fuse_param=False,
+        scale_after_comm=True,
+    ):
+        self._id = id
+        self._params = params
+        self._acc_steps = acc_steps
+        self._comm_group = comm_group
+        self._scale_after_comm = scale_after_comm
+        self._fuse_param = fuse_param
+
+        self.use_main_grad = (
+            use_main_grad
+            if use_main_grad is not None
+            else hasattr(self._params[0], "main_grad")
+        )
+
+        self._task = None
+        self._params_step_dict = {}
+        self._params_checked_in = 0
+        self._grads_to_addr = {}
+
+        self._act = act
+        if self._act == HOOK_ACTION.ALL_REDUCE:
+            assert dst == -1
+        elif self._act == HOOK_ACTION.REDUCE:
+            assert dst != -1
+        else:
+            raise ValueError(
+                "The act should be allreudce for dp or reduce for sharding."
+            )
+        self._dst = dst
+
+        self._init_step_dict()
+
+        if self._fuse_param:
+            self.param_storage, self.grad_storage = flatten_dense_tensors(
+                self._params,
+                use_main_grad=use_main_grad,
+                fuse_param=True,
+                warp_buffer=True,
+            )
+            self.param_storage = self.param_storage.buffer
+            self.grad_storage = self.grad_storage.buffer
+        else:
+            self.param_storage = None
+            self.grad_storage = flatten_dense_tensors(
+                self._params,
+                use_main_grad=self.use_main_grad,
+                fuse_param=False,
+                warp_buffer=False,
+            ).buffer
+
+        self._record_addr()
+
+    def _record_addr(self):
+        for param in self._params:
+            addr = (
+                param.main_grad.data_ptr()
+                if self.use_main_grad
+                else param.grad.data_ptr()
+            )
+            self._grads_to_addr[param.name] = addr
+
+    def _init_step_dict(self):
+        for p in self._params:
+            self._params_step_dict[p.name] = 0
+
+    def _reset_params_checked_in(self):
+        self._task = None
+        self._init_step_dict()
+        self._params_checked_in = 0
+
+    @property
+    def _all_params_checked_in(self):
+        return (
+            len(self._params) == self._params_checked_in
+            and len(self._params_step_dict) == 0
+        )
+
+    def add_grad(self, param):
+        assert param.name in self._params_step_dict
+        current_ptr = (
+            param.main_grad.data_ptr()
+            if self.use_main_grad
+            else param.grad.data_ptr()
+        )
+        if self._grads_to_addr[param.name] != current_ptr:
+            raise ValueError(
+                "The address of the grad/main_grad of the param has been changed during training, "
+                "which is not allowed for dp/sharding overlap with pp. "
+                "This may be caused by some non-inplace operations on the grad/main_grad. "
+                "Please use the inplace version of the operations or disable the overlapping."
+            )
+
+        self._params_step_dict[param.name] += 1
+
+        if self._params_step_dict[param.name] == self._acc_steps:
+            self._params_checked_in += 1
+            self._params_step_dict.pop(param.name)
+
+        if self._all_params_checked_in:
+            self._comm_grads()
+
+    @imperative_base.no_grad
+    def _comm_grads(self):
+        assert self._all_params_checked_in
+
+        if not self._scale_after_comm:
+            scale_factor = 1.0 / self._comm_group.nranks
+            self.grad_storage.scale_(scale_factor)
+
+        if self._act == HOOK_ACTION.ALL_REDUCE:
+            task = paddle.distributed.all_reduce(
+                self.grad_storage, group=self._comm_group, sync_op=False
+            )
+        elif self._act == HOOK_ACTION.REDUCE:
+            task = paddle.distributed.reduce(
+                self.grad_storage,
+                dst=self._dst,
+                group=self._comm_group,
+                sync_op=False,
+            )
+
+        self._task = task
+
+    @imperative_base.no_grad
+    def scale_grads(self):
+        assert self._task is not None
+        self._task.wait()
+
+        if self._scale_after_comm:
+            scale_factor = 1.0 / self._comm_group.nranks
+            self.grad_storage.scale_(scale_factor)
+
+        self._reset_params_checked_in()
+
+
+def obtain_storage(
+    parameters,
+    use_main_grad=False,
+    clip=True,
+    dist=False,
+    fuse_param=True,
+    comm_overlap=False,
+    act=None,
+    comm_group=None,
+    dst=-1,
+    acc_steps=1,
+    scale_after_comm=False,
+):
     if len(parameters) < 1:
-        return []
+        return [], []

     var_groups = assign_group_by_size(parameters, group_size=256 * 1024 * 1024)
     storage = []
+    buffers = []
     for group_idx, parameters in var_groups.items():
-        param_storage, grad_storage = flatten_dense_tensors(
+        comm_buffer = FusedCommBuffer(
+            group_idx,
             parameters,
+            comm_group=comm_group,
+            acc_steps=acc_steps,
+            act=act,
+            dst=dst,
             use_main_grad=use_main_grad,
-            fuse_param=True,
-            warp_buffer=True,
+            fuse_param=fuse_param,
+            scale_after_comm=scale_after_comm,
         )
-        param_storage.buffer.need_clip = clip
-        param_storage.buffer.is_distributed = dist
-        storage.append(param_storage.buffer)
-    return storage
+        if fuse_param:
+            param_buffer = comm_buffer.param_storage
+            param_buffer.need_clip = clip
+            param_buffer.is_distributed = dist
+            storage.append(param_buffer)
+        if comm_overlap:
+            for param in parameters:
+                param._register_backward_hook(bw_hook_func(comm_buffer, param))
+            buffers.append(comm_buffer)
+
+    return storage, buffers


 def filter_params(params, is_fp32, is_distributed, need_clip):

@@ -155,7 +344,38 @@ def filter_params(params, is_fp32, is_distributed, need_clip):
     return params, dtype


-def fused_parameters(parameters, use_main_grad):
+def fused_parameters(
+    parameters,
+    use_main_grad=False,
+    fuse_param=True,
+    comm_overlap=False,
+    comm_group=None,
+    dst=-1,
+    acc_step=1,
+    scale_after_comm=False,
+):
+    """
+    Fuse gradients. Fuse parameters if be enabled. Prepare for comm overlap if be enabled.
+    :param parameters: all parameters to be fused.
+    :param use_main_grad: does the gradient use main grad or not
+    :param comm_overlap: enable comm overlap or not
+    :param comm_group: the comm group for comm overlap
+    :param dst: the dst for comm overlap
+    :param acc_step: acc steps, using for comm overlap
+    :param fuse_param: fuse param or not
+    :param scale_after_comm: if enable comm overlap, specify the location of grad scale
+    :return: param storage if fused, comm buffers is comm overlap
+    """
+    g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 0))
+    act = (
+        HOOK_ACTION.ALL_REDUCE if not g_shard_use_reduce else HOOK_ACTION.REDUCE
+    )
+    if comm_overlap:
+        assert comm_group is not None
+    if act == HOOK_ACTION.REDUCE:
+        assert dst != -1
+    elif act == HOOK_ACTION.ALL_REDUCE:
+        dst = -1
     param_groups = []
     attrs = []

@@ -178,6 +398,7 @@ def fused_parameters(parameters, use_main_grad):
     decay_fused = []
     all_fused = []
+    all_buffers = []
     for params, attr in zip(param_groups, attrs):
         decay_params = []
         other_params = []

@@ -190,14 +411,36 @@ def fused_parameters(parameters, use_main_grad):
         is_distributed = attr[1]
         need_clip = attr[2]

-        decay = obtain_storage(
-            decay_params, use_main_grad, need_clip, is_distributed
-        )
+        decay, decay_buffers = obtain_storage(
+            decay_params,
+            use_main_grad=use_main_grad,
+            clip=need_clip,
+            dist=is_distributed,
+            fuse_param=fuse_param,
+            comm_overlap=comm_overlap,
+            act=act,
+            comm_group=comm_group,
+            dst=dst,
+            acc_steps=acc_step,
+            scale_after_comm=scale_after_comm,
+        )
-        other = obtain_storage(
-            other_params, use_main_grad, need_clip, is_distributed
-        )
+        other, other_buffers = obtain_storage(
+            other_params,
+            fuse_param=fuse_param,
+            comm_overlap=comm_overlap,
+            use_main_grad=use_main_grad,
+            clip=need_clip,
+            dist=is_distributed,
+            act=act,
+            comm_group=comm_group,
+            dst=dst,
+            acc_steps=acc_step,
+            scale_after_comm=scale_after_comm,
+        )
         decay_fused += decay
         all_fused += decay
         all_fused += other
+        all_buffers += decay_buffers
+        all_buffers += other_buffers

-    return decay_fused, all_fused
+    return decay_fused, all_fused, all_buffers
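The extended fused_parameters signature above is exercised by DygraphShardingOptimizer._tensor_fusion earlier in this diff. A condensed sketch of that call pattern, wrapped in a hypothetical helper (build_sharding_comm_buffers, hcg, rank2params and the other arguments are placeholders for state the optimizer owns, not part of the commit):

from paddle.distributed.fleet.utils.tensor_fusion_helper import fused_parameters


def build_sharding_comm_buffers(rank2params, hcg, use_main_grad, accumulate_steps):
    # rank2params: mapping from sharding rank to the parameter list assigned to that rank.
    comm_group = hcg.get_sharding_parallel_group()
    rank2fused, comm_buffers = {}, []
    for rank, params in rank2params.items():
        decay_fused, all_fused, all_buffer = fused_parameters(
            params,
            use_main_grad=use_main_grad,
            fuse_param=True,
            comm_overlap=True,            # registers a backward hook on every parameter
            comm_group=comm_group,
            dst=comm_group.ranks[rank],   # only consulted when FLAGS_shard_use_reduce selects REDUCE
            acc_step=accumulate_steps,
            scale_after_comm=False,       # scale gradients before launching the reduce
        )
        rank2fused[rank] = all_fused
        comm_buffers += all_buffer
    return rank2fused, comm_buffers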
test/collective/fleet/hybrid_parallel_sharding_model_with_fusion.py @ a9f877ff

@@ -99,6 +99,8 @@ class TestDistSharding(unittest.TestCase):
             "pp_degree": 1,
         }
         self.strategy.hybrid_configs["sharding_configs"].tensor_fusion = True
+        self.strategy.hybrid_configs["sharding_configs"].comm_overlap = True
+        self.strategy.hybrid_configs["sharding_configs"].accumulate_steps = 1
         fleet.init(is_collective=True, strategy=self.strategy)
         self.data = np.random.randint(
             0,
test/legacy_test/test_fused_comm_buffer.py @ a9f877ff

@@ -15,7 +15,7 @@
 import unittest

 import paddle
-from paddle.distributed.fleet.meta_parallel.pp_utils.utils import (
+from paddle.distributed.fleet.utils.tensor_fusion_helper import (
     HOOK_ACTION,
     FusedCommBuffer,
 )