Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
a00f5bd4
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
a00f5bd4
编写于
8月 01, 2023
作者:
Y
Yuang Liu
提交者:
GitHub
8月 01, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Optimize check finite when using sharding comm overlap. (#55766)
上级
dc82fa96
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
164 addition
and
17 deletion
+164
-17
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
...ptimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
+22
-10
python/paddle/distributed/fleet/scaler.py
python/paddle/distributed/fleet/scaler.py
+22
-7
test/collective/fleet/hybrid_parallel_sharding_model_with_fusion_amp.py
...e/fleet/hybrid_parallel_sharding_model_with_fusion_amp.py
+117
-0
test/collective/fleet/test_parallel_dygraph_sharding_parallel.py
...llective/fleet/test_parallel_dygraph_sharding_parallel.py
+3
-0
未找到文件。
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
浏览文件 @
a00f5bd4
...
...
@@ -96,20 +96,17 @@ class DygraphShardingOptimizer:
self
.
_rank2params
=
self
.
_partition_parameters
()
self
.
_param2rank
=
self
.
_map_param_to_rank
()
if
not
self
.
tensor_fusion
:
self
.
_set_inner_opt_attr
(
'_parameter_list'
,
self
.
_rank2params
[
self
.
_sharding_rank
]
)
self
.
_set_inner_opt_attr
(
'_param_groups'
,
self
.
_rank2params
[
self
.
_sharding_rank
]
)
if
not
self
.
tensor_fusion
and
not
self
.
comm_overlap
:
local_params
=
self
.
_rank2params
[
self
.
_sharding_rank
]
self
.
_set_inner_opt_attr
(
'_parameter_list'
,
local_params
)
self
.
_set_inner_opt_attr
(
'_param_groups'
,
local_params
)
else
:
self
.
_tensor_fusion
()
decay_params
=
[
p
.
name
for
p
in
self
.
_rank2decay
[
self
.
_sharding_rank
]
]
fused_params
=
self
.
_rank2fused
[
self
.
_sharding_rank
]
local_
fused_params
=
self
.
_rank2fused
[
self
.
_sharding_rank
]
apply_decay_param_fun
=
lambda
x
:
x
in
decay_params
all_fused_params
=
[]
...
...
@@ -118,8 +115,15 @@ class DygraphShardingOptimizer:
self
.
_parameter_list
=
all_fused_params
self
.
_param_groups
=
all_fused_params
self
.
_set_inner_opt_attr
(
'_parameter_list'
,
fused_params
)
self
.
_set_inner_opt_attr
(
'_param_groups'
,
fused_params
)
self
.
_set_inner_opt_attr
(
'_parameter_list'
,
local_fused_params
)
self
.
_set_inner_opt_attr
(
'_param_groups'
,
local_fused_params
)
if
self
.
comm_overlap
:
# Only set local param for check finite when comm overlap.
# Under comm overlap, all grads will be communicated before check_finite.
# Therefore, each sharding rank can get all grads' info at check_finite.
# Without comm overlap, all grads will be communicated after check_finite,
# which means each sharding rank should do check_finite to all grads.
self
.
_local_parameter_list
=
local_fused_params
origin_decay_param_fun
=
getattr
(
self
.
_inner_opt
,
'_apply_decay_param_fun'
,
None
)
...
...
@@ -127,6 +131,14 @@ class DygraphShardingOptimizer:
self
.
_set_inner_opt_attr
(
'_apply_decay_param_fun'
,
apply_decay_param_fun
)
# Note: during the tensor fusion for parameters, the allocator will apply for
# some extra GPU memory for the fused big paramters. This extra GPU memory will
# be useless at once the fusion has done. But the Paddle's allocator won't
# release those memory, it will hold that part in the memory poll. So after
# tensor fusion, the 'reserved' memory will increase but the 'allocate' memory
# won't change. To avoid failure on some other applications (such as some nvtx
# operations), here we manulay let the allocator release the cached memory.
paddle
.
device
.
cuda
.
empty_cache
()
def
clear_grad
(
self
,
set_to_zero
=
True
):
"""
...
...
python/paddle/distributed/fleet/scaler.py
浏览文件 @
a00f5bd4
...
...
@@ -47,20 +47,35 @@ def distributed_scaler(scaler):
else
:
param_grads_fp32
.
append
(
param
.
_grad_ivar
())
else
:
param_grads
=
[
param
.
_grad_ivar
()
for
param
in
optimizer
.
_parameter_list
if
param
.
_grad_ivar
()
is
not
None
]
strategy
=
fleet
.
fleet
.
_user_defined_strategy
sharding_stage_1_overlap
=
strategy
.
hybrid_configs
[
'sharding_configs'
].
comm_overlap
if
sharding_stage_1_overlap
:
# If sharding stage 1 enable comm overlap and need do loss scale. Here we have to wait all comm tasks.
# If no need do loss scale, the wait for all comm tasks will do in the optimizer step.
assert
hasattr
(
optimizer
,
"_comm_buffers"
)
assert
hasattr
(
optimizer
,
"_sharding_enable"
)
if
optimizer
.
_sharding_enable
:
# disable origin grad reduce in hybrid optimizer step
optimizer
.
_sharding_enable
=
False
for
buffer
in
optimizer
.
_comm_buffers
:
buffer
.
scale_grads
()
# For sharding stage 1 under comm overlap, each rank only have to check finite for the response params.
# For now, only sharding stage 1 contains this attr, this can be promoted to stage 2 and stage 3.
assert
hasattr
(
optimizer
,
"_local_parameter_list"
)
parameters
=
optimizer
.
_local_parameter_list
else
:
parameters
=
optimizer
.
_parameter_list
param_grads_fp16
=
[
param
.
_grad_ivar
()
for
param
in
optimizer
.
_parameter_list
for
param
in
parameters
if
(
param
.
_grad_ivar
()
is
not
None
)
and
(
param
.
_grad_ivar
().
dtype
==
core
.
VarDesc
.
VarType
.
FP16
)
]
param_grads_fp32
=
[
param
.
_grad_ivar
()
for
param
in
optimizer
.
_parameter_list
for
param
in
parameters
if
(
param
.
_grad_ivar
()
is
not
None
)
and
(
param
.
_grad_ivar
().
dtype
==
core
.
VarDesc
.
VarType
.
FP32
)
]
...
...
test/collective/fleet/hybrid_parallel_sharding_model_with_fusion_amp.py
0 → 100644
浏览文件 @
a00f5bd4
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
import
paddle
from
paddle.distributed
import
fleet
vocab_size
=
20
hidden_size
=
10
inner_size
=
8
output_size
=
10
seq_length
=
2
batch_size
=
4
STEPS
=
10
class
SimpleDPNet
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
,
vocab_size
,
hidden_size
,
inner_size
,
output_size
):
super
().
__init__
()
self
.
linear1
=
paddle
.
nn
.
Linear
(
hidden_size
,
inner_size
)
self
.
linear2
=
paddle
.
nn
.
Linear
(
inner_size
,
hidden_size
)
self
.
linear3
=
paddle
.
nn
.
Linear
(
hidden_size
,
output_size
)
self
.
embedding
=
paddle
.
nn
.
Embedding
(
vocab_size
,
hidden_size
)
def
forward
(
self
,
x
):
x
=
self
.
embedding
(
x
)
x
=
self
.
linear1
(
x
)
x
=
self
.
linear2
(
x
)
x
=
self
.
linear3
(
x
)
x
=
paddle
.
matmul
(
x
,
self
.
embedding
.
weight
,
transpose_y
=
True
)
return
x
class
TestDistSharding
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
strategy
=
fleet
.
DistributedStrategy
()
self
.
strategy
.
hybrid_configs
=
{
"sharding_degree"
:
2
,
"dp_degree"
:
1
,
"mp_degree"
:
1
,
"pp_degree"
:
1
,
}
self
.
strategy
.
hybrid_configs
[
"sharding_configs"
].
tensor_fusion
=
True
self
.
strategy
.
hybrid_configs
[
"sharding_configs"
].
comm_overlap
=
True
self
.
strategy
.
hybrid_configs
[
"sharding_configs"
].
accumulate_steps
=
1
fleet
.
init
(
is_collective
=
True
,
strategy
=
self
.
strategy
)
self
.
data
=
np
.
random
.
randint
(
0
,
vocab_size
,
(
batch_size
,
seq_length
,
),
)
if
paddle
.
distributed
.
get_rank
()
==
0
:
self
.
batch_sharding
=
paddle
.
to_tensor
(
self
.
data
[:
2
])
else
:
self
.
batch_sharding
=
paddle
.
to_tensor
(
self
.
data
[
2
:])
def
build_optimizer
(
self
,
model
):
clip
=
paddle
.
nn
.
ClipGradByGlobalNorm
(
0.5
)
optimizer
=
paddle
.
optimizer
.
AdamW
(
parameters
=
model
.
parameters
(),
learning_rate
=
0.001
,
weight_decay
=
0.001
,
grad_clip
=
clip
,
)
return
optimizer
def
build_model_optimizer
(
self
):
model
=
SimpleDPNet
(
vocab_size
,
hidden_size
,
inner_size
,
output_size
)
optimizer
=
self
.
build_optimizer
(
model
)
model
,
optimizer
=
paddle
.
amp
.
decorate
(
model
,
optimizers
=
optimizer
,
level
=
"O2"
,
dtype
=
"float16"
)
scaler
=
paddle
.
amp
.
GradScaler
(
init_loss_scaling
=
1024
)
scaler
=
fleet
.
distributed_scaler
(
scaler
)
model
=
fleet
.
distributed_model
(
model
)
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
)
return
model
,
optimizer
,
scaler
def
sharding_model
(
self
):
model
,
optimizer
,
scaler
=
self
.
build_model_optimizer
()
for
idx
in
range
(
STEPS
):
with
paddle
.
amp
.
auto_cast
(
enable
=
True
,
level
=
'O2'
):
output
=
model
(
self
.
batch_sharding
)
loss
=
output
.
mean
()
scaler
.
scale
(
loss
).
backward
()
scaler
.
step
(
optimizer
)
scaler
.
update
()
optimizer
.
clear_grad
()
def
test_sharding_adam
(
self
):
self
.
sharding_model
()
if
__name__
==
"__main__"
:
unittest
.
main
()
test/collective/fleet/test_parallel_dygraph_sharding_parallel.py
浏览文件 @
a00f5bd4
...
...
@@ -33,6 +33,9 @@ class TestHybridParallel(TestMultipleGpus):
def
test_hybrid_parallel_sharding_tensor_fusion
(
self
):
self
.
run_mnist_2gpu
(
'hybrid_parallel_sharding_model_with_fusion.py'
)
def
test_hybrid_parallel_sharding_tensor_fusion_amp
(
self
):
self
.
run_mnist_2gpu
(
'hybrid_parallel_sharding_model_with_fusion_amp.py'
)
def
test_hybrid_parallel_sharding_state_dict
(
self
):
self
.
run_mnist_2gpu
(
'hybrid_parallel_sharding_state_dict.py'
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录