Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
1a4a1520
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1a4a1520
编写于
2月 07, 2023
作者:
W
wuhuachaocoding
提交者:
GitHub
2月 07, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support bfp16 for stage3 and offload. (#49931)
上级
05c9c0a5
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
184 addition
and
16 deletion
+184
-16
python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py
...uted/fleet/meta_parallel/sharding/group_sharded_stage3.py
+24
-6
python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
...buted/fleet/meta_parallel/sharding/group_sharded_utils.py
+53
-4
python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage3.py
...nittests/collective/fleet/dygraph_group_sharded_stage3.py
+65
-3
python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage3_offload.py
.../collective/fleet/dygraph_group_sharded_stage3_offload.py
+42
-3
未找到文件。
python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py
浏览文件 @
1a4a1520
...
@@ -45,10 +45,9 @@ def _all_gather(tensor, buffer_size, group):
...
@@ -45,10 +45,9 @@ def _all_gather(tensor, buffer_size, group):
# CUDA alignment 256 bytes
# CUDA alignment 256 bytes
alignment
=
{
alignment
=
{
"gpu"
:
256
,
"cpu"
:
4096
,
"xpu"
:
256
}
"gpu"
:
256
,
}
align
=
{
align
=
{
Type
.
bf16
.
value
:
2
,
Type
.
fp16
.
value
:
2
,
Type
.
fp16
.
value
:
2
,
Type
.
fp32
.
value
:
4
,
Type
.
fp32
.
value
:
4
,
}
}
...
@@ -251,6 +250,11 @@ class GroupShardedStage3(nn.Layer):
...
@@ -251,6 +250,11 @@ class GroupShardedStage3(nn.Layer):
and
param2dtype
[
param
.
name
]
==
Type
.
fp16
.
value
and
param2dtype
[
param
.
name
]
==
Type
.
fp16
.
value
):
):
tmp_var
=
paddle
.
cast
(
tmp_var
,
Type
.
fp16
.
value
)
tmp_var
=
paddle
.
cast
(
tmp_var
,
Type
.
fp16
.
value
)
elif
(
tmp_var
.
dtype
==
Type
.
fp32
.
value
and
param2dtype
[
param
.
name
]
==
Type
.
bf16
.
value
):
tmp_var
=
paddle
.
cast
(
tmp_var
,
Type
.
bf16
.
value
)
tmp_var
.
_share_buffer_to
(
param
)
tmp_var
.
_share_buffer_to
(
param
)
del
tmp_var
del
tmp_var
for
grad_storage
in
self
.
_grad_storages
.
values
():
for
grad_storage
in
self
.
_grad_storages
.
values
():
...
@@ -312,11 +316,14 @@ class GroupShardedStage3(nn.Layer):
...
@@ -312,11 +316,14 @@ class GroupShardedStage3(nn.Layer):
def
_handle_unslice_params
(
self
):
def
_handle_unslice_params
(
self
):
buffer_size
=
dict
()
buffer_size
=
dict
()
buffer_size
[
Type
.
bf16
.
value
]
=
0
buffer_size
[
Type
.
fp32
.
value
]
=
0
buffer_size
[
Type
.
fp32
.
value
]
=
0
buffer_size
[
Type
.
fp16
.
value
]
=
0
buffer_size
[
Type
.
fp16
.
value
]
=
0
for
param
in
self
.
_unslice_params
:
for
param
in
self
.
_unslice_params
:
# Updata optimizer master weights
# Updata optimizer master weights
if
param
.
dtype
==
Type
.
fp16
.
value
and
not
self
.
_offload
:
if
(
param
.
dtype
==
Type
.
fp16
.
value
or
param
.
dtype
==
Type
.
bf16
.
value
)
and
not
self
.
_offload
:
master_tensor
=
paddle
.
cast
(
param
,
Type
.
fp32
.
value
)
master_tensor
=
paddle
.
cast
(
param
,
Type
.
fp32
.
value
)
master_tensor
.
name
=
param
.
name
master_tensor
.
name
=
param
.
name
self
.
_optim
.
_master_weights
[
param
.
name
]
=
master_tensor
self
.
_optim
.
_master_weights
[
param
.
name
]
=
master_tensor
...
@@ -419,10 +426,14 @@ class GroupShardedStage3(nn.Layer):
...
@@ -419,10 +426,14 @@ class GroupShardedStage3(nn.Layer):
assert
isinstance
(
buffer_size
,
int
)
assert
isinstance
(
buffer_size
,
int
)
value
=
(
value
=
(
np
.
zeros
(
buffer_size
,
dtype
=
np
.
float16
)
np
.
zeros
(
buffer_size
,
dtype
=
np
.
float16
)
if
Type
.
fp16
.
value
==
param
.
dtype
if
(
Type
.
fp16
.
value
==
param
.
dtype
or
Type
.
bf16
.
value
==
param
.
dtype
)
else
np
.
zeros
(
buffer_size
,
dtype
=
np
.
float32
)
else
np
.
zeros
(
buffer_size
,
dtype
=
np
.
float32
)
)
)
buffer
=
core
.
eager
.
Tensor
(
value
=
value
,
place
=
core
.
CPUPlace
())
buffer
=
core
.
eager
.
Tensor
(
value
=
value
,
place
=
core
.
CPUPlace
())
if
Type
.
bf16
.
value
==
param
.
dtype
:
buffer
=
buffer
.
cast
(
Type
.
bf16
.
value
)
param_shape
=
param
.
shape
param_shape
=
param
.
shape
origin_state
=
param
.
stop_gradient
origin_state
=
param
.
stop_gradient
...
@@ -462,7 +473,9 @@ class GroupShardedStage3(nn.Layer):
...
@@ -462,7 +473,9 @@ class GroupShardedStage3(nn.Layer):
# Updata optimizer master weights
# Updata optimizer master weights
if
(
if
(
param
.
trainable
param
.
trainable
and
param
.
dtype
==
Type
.
fp16
.
value
and
(
param
.
dtype
==
Type
.
fp16
.
value
or
param
.
dtype
==
Type
.
bf16
.
value
)
and
not
self
.
_offload
and
not
self
.
_offload
):
):
master_tensor
=
paddle
.
cast
(
param
.
fw_storage
,
Type
.
fp32
.
value
)
master_tensor
=
paddle
.
cast
(
param
.
fw_storage
,
Type
.
fp32
.
value
)
...
@@ -1088,6 +1101,11 @@ def _cpu2device(param):
...
@@ -1088,6 +1101,11 @@ def _cpu2device(param):
and
param2dtype
[
param
.
name
]
==
Type
.
fp16
.
value
and
param2dtype
[
param
.
name
]
==
Type
.
fp16
.
value
):
):
tmp_p
=
paddle
.
cast
(
tmp_p
,
Type
.
fp16
.
value
)
tmp_p
=
paddle
.
cast
(
tmp_p
,
Type
.
fp16
.
value
)
elif
(
tmp_p
.
dtype
==
Type
.
fp32
.
value
and
param2dtype
[
param
.
name
]
==
Type
.
bf16
.
value
):
tmp_p
=
paddle
.
cast
(
tmp_p
,
Type
.
bf16
.
value
)
return
tmp_p
return
tmp_p
...
...
python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
浏览文件 @
1a4a1520
...
@@ -54,8 +54,12 @@ class GroupShardedClipGrad:
...
@@ -54,8 +54,12 @@ class GroupShardedClipGrad:
@
paddle
.
autograd
.
no_grad
()
@
paddle
.
autograd
.
no_grad
()
def
_dygraph_clip
(
self
,
params_grads
):
def
_dygraph_clip
(
self
,
params_grads
):
sum_square_fp32
,
sum_square_fp16
=
[],
[]
sum_square_fp32
,
sum_square_fp16
,
sum_square_bfp16
=
[],
[],
[]
unslice_params_fp32
,
unslice_params_fp16
=
[],
[]
unslice_params_fp32
,
unslice_params_fp16
,
unslice_params_bfp16
=
(
[],
[],
[],
)
for
p
,
g
in
params_grads
:
for
p
,
g
in
params_grads
:
p_slice
=
True
# using for slice parameter in sharding stage3
p_slice
=
True
# using for slice parameter in sharding stage3
...
@@ -82,6 +86,11 @@ class GroupShardedClipGrad:
...
@@ -82,6 +86,11 @@ class GroupShardedClipGrad:
sum_square_fp32
.
append
(
sum_square
)
sum_square_fp32
.
append
(
sum_square
)
else
:
else
:
unslice_params_fp32
.
append
(
sum_square
)
unslice_params_fp32
.
append
(
sum_square
)
elif
p
.
dtype
==
paddle
.
bfloat16
:
if
p_slice
:
sum_square_bfp16
.
append
(
sum_square
)
else
:
unslice_params_bfp16
.
append
(
sum_square
)
# global norm of non-distributed FP16 params_and_grads
# global norm of non-distributed FP16 params_and_grads
if
len
(
sum_square_fp16
)
==
0
:
if
len
(
sum_square_fp16
)
==
0
:
...
@@ -93,6 +102,16 @@ class GroupShardedClipGrad:
...
@@ -93,6 +102,16 @@ class GroupShardedClipGrad:
global_norm_fp16
,
dtype
=
paddle
.
float32
global_norm_fp16
,
dtype
=
paddle
.
float32
)
)
# global norm of non-distributed BFP16 params_and_grads
if
len
(
sum_square_bfp16
)
==
0
:
global_norm_bfp16
=
paddle
.
to_tensor
([
0.0
],
dtype
=
paddle
.
float32
)
else
:
global_norm_bfp16
=
paddle
.
concat
(
sum_square_bfp16
)
global_norm_bfp16
=
paddle
.
sum
(
global_norm_bfp16
)
global_norm_bfp16
=
paddle
.
cast
(
global_norm_bfp16
,
dtype
=
paddle
.
float32
)
# global norm of non-distributed FP16 params_and_grads for unslice parameters
# global norm of non-distributed FP16 params_and_grads for unslice parameters
if
len
(
unslice_params_fp16
)
==
0
:
if
len
(
unslice_params_fp16
)
==
0
:
global_unslice_fp16
=
paddle
.
to_tensor
([
0.0
],
dtype
=
paddle
.
float32
)
global_unslice_fp16
=
paddle
.
to_tensor
([
0.0
],
dtype
=
paddle
.
float32
)
...
@@ -103,6 +122,16 @@ class GroupShardedClipGrad:
...
@@ -103,6 +122,16 @@ class GroupShardedClipGrad:
global_unslice_fp16
,
dtype
=
paddle
.
float32
global_unslice_fp16
,
dtype
=
paddle
.
float32
)
)
# global norm of non-distributed BFP16 params_and_grads for unslice parameters
if
len
(
unslice_params_bfp16
)
==
0
:
global_unslice_bfp16
=
paddle
.
to_tensor
([
0.0
],
dtype
=
paddle
.
float32
)
else
:
global_unslice_bfp16
=
paddle
.
concat
(
unslice_params_bfp16
)
global_unslice_bfp16
=
paddle
.
sum
(
global_unslice_bfp16
)
global_unslice_bfp16
=
paddle
.
cast
(
global_unslice_bfp16
,
dtype
=
paddle
.
float32
)
# global norm of non-distributed FP32 params_and_grads
# global norm of non-distributed FP32 params_and_grads
global_norm_fp32
=
(
global_norm_fp32
=
(
paddle
.
concat
(
sum_square_fp32
)
paddle
.
concat
(
sum_square_fp32
)
...
@@ -118,9 +147,13 @@ class GroupShardedClipGrad:
...
@@ -118,9 +147,13 @@ class GroupShardedClipGrad:
else
paddle
.
to_tensor
([
0.0
],
dtype
=
paddle
.
float32
)
else
paddle
.
to_tensor
([
0.0
],
dtype
=
paddle
.
float32
)
)
)
global_unslice_fp32
=
paddle
.
sum
(
global_unslice_fp32
)
global_unslice_fp32
=
paddle
.
sum
(
global_unslice_fp32
)
global_unslice_var
=
global_unslice_fp16
+
global_unslice_fp32
global_unslice_var
=
(
global_unslice_fp16
+
global_unslice_fp32
+
global_unslice_bfp16
)
global_norm_var
=
global_norm_fp16
+
global_norm_fp32
global_norm_var
=
(
global_norm_fp16
+
global_norm_fp32
+
global_norm_bfp16
)
# add all reduce to get global norm of distributed params_and_grads
# add all reduce to get global norm of distributed params_and_grads
dev_id
=
int
(
self
.
_device
.
split
(
":"
)[
1
])
dev_id
=
int
(
self
.
_device
.
split
(
":"
)[
1
])
...
@@ -181,6 +214,7 @@ def GroupShardedScaler(scaler):
...
@@ -181,6 +214,7 @@ def GroupShardedScaler(scaler):
if
not
self
.
_enable
:
if
not
self
.
_enable
:
return
return
param_grads
=
[]
param_grads
=
[]
param_grads_bfp16
=
[]
param_grads_fp16
=
[]
param_grads_fp16
=
[]
param_grads_fp32
=
[]
param_grads_fp32
=
[]
if
hasattr
(
optimizer
,
"update_slice"
):
if
hasattr
(
optimizer
,
"update_slice"
):
...
@@ -200,6 +234,8 @@ def GroupShardedScaler(scaler):
...
@@ -200,6 +234,8 @@ def GroupShardedScaler(scaler):
paddle
.
float16
,
paddle
.
float16
,
]:
]:
param_grads_fp16
.
append
(
param
.
grad
)
param_grads_fp16
.
append
(
param
.
grad
)
elif
param
.
grad
.
dtype
in
[
paddle
.
bfloat16
]:
param_grads_bfp16
.
append
(
param
.
grad
)
else
:
else
:
param_grads_fp32
.
append
(
param
.
grad
)
param_grads_fp32
.
append
(
param
.
grad
)
else
:
else
:
...
@@ -211,10 +247,13 @@ def GroupShardedScaler(scaler):
...
@@ -211,10 +247,13 @@ def GroupShardedScaler(scaler):
paddle
.
float16
,
paddle
.
float16
,
]:
]:
param_grads_fp16
.
append
(
param
.
grad
)
param_grads_fp16
.
append
(
param
.
grad
)
elif
param
.
grad
.
dtype
in
[
paddle
.
bfloat16
]:
param_grads_bfp16
.
append
(
param
.
grad
)
else
:
else
:
param_grads_fp32
.
append
(
param
.
grad
)
param_grads_fp32
.
append
(
param
.
grad
)
temp_found_inf_fp16
=
to_variable
(
np
.
array
([
0
]).
astype
(
np
.
bool_
))
temp_found_inf_fp16
=
to_variable
(
np
.
array
([
0
]).
astype
(
np
.
bool_
))
temp_found_inf_bfp16
=
to_variable
(
np
.
array
([
0
]).
astype
(
np
.
bool_
))
temp_found_inf_fp32
=
to_variable
(
np
.
array
([
0
]).
astype
(
np
.
bool_
))
temp_found_inf_fp32
=
to_variable
(
np
.
array
([
0
]).
astype
(
np
.
bool_
))
device
=
paddle
.
get_device
().
split
(
":"
)[
0
]
device
=
paddle
.
get_device
().
split
(
":"
)[
0
]
...
@@ -224,6 +263,16 @@ def GroupShardedScaler(scaler):
...
@@ -224,6 +263,16 @@ def GroupShardedScaler(scaler):
)
)
with
device_guard
(
dev_id
,
device
):
with
device_guard
(
dev_id
,
device
):
if
len
(
param_grads_bfp16
):
_legacy_C_ops
.
check_finite_and_unscale
(
param_grads_bfp16
,
self
.
_scale
,
param_grads_bfp16
,
temp_found_inf_bfp16
,
)
self
.
_found_inf
=
_C_ops
.
bitwise_or
(
self
.
_found_inf
,
temp_found_inf_bfp16
)
if
len
(
param_grads_fp16
):
if
len
(
param_grads_fp16
):
_legacy_C_ops
.
check_finite_and_unscale
(
_legacy_C_ops
.
check_finite_and_unscale
(
param_grads_fp16
,
param_grads_fp16
,
...
...
python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage3.py
浏览文件 @
1a4a1520
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
import
os
import
os
import
shutil
import
shutil
import
subprocess
import
tempfile
import
tempfile
import
numpy
as
np
import
numpy
as
np
...
@@ -135,6 +136,7 @@ def train_mlp(
...
@@ -135,6 +136,7 @@ def train_mlp(
model
,
model
,
sharding_stage
,
sharding_stage
,
use_pure_fp16
=
False
,
use_pure_fp16
=
False
,
use_bfp16
=
False
,
accumulate_grad
=
False
,
accumulate_grad
=
False
,
batch_size
=
100
,
batch_size
=
100
,
opt_group
=
False
,
opt_group
=
False
,
...
@@ -154,7 +156,10 @@ def train_mlp(
...
@@ -154,7 +156,10 @@ def train_mlp(
if
use_pure_fp16
:
if
use_pure_fp16
:
model
=
paddle
.
amp
.
decorate
(
model
=
paddle
.
amp
.
decorate
(
models
=
model
,
level
=
'O2'
,
save_dtype
=
'float32'
models
=
model
,
level
=
'O2'
,
save_dtype
=
'float32'
,
dtype
=
'bfloat16'
if
use_bfp16
else
'float16'
,
)
)
scaler
=
paddle
.
amp
.
GradScaler
(
init_loss_scaling
=
32768
)
scaler
=
paddle
.
amp
.
GradScaler
(
init_loss_scaling
=
32768
)
scaler
=
GroupShardedScaler
(
scaler
)
scaler
=
GroupShardedScaler
(
scaler
)
...
@@ -201,7 +206,11 @@ def train_mlp(
...
@@ -201,7 +206,11 @@ def train_mlp(
img
,
label
=
data
img
,
label
=
data
label
.
stop_gradient
=
True
label
.
stop_gradient
=
True
img
.
stop_gradient
=
True
img
.
stop_gradient
=
True
with
paddle
.
amp
.
auto_cast
(
use_pure_fp16
,
level
=
'O2'
):
with
paddle
.
amp
.
auto_cast
(
use_pure_fp16
,
level
=
'O2'
,
dtype
=
'bfloat16'
if
use_bfp16
else
'float16'
,
):
out
=
model
(
img
)
out
=
model
(
img
)
loss
=
paddle
.
nn
.
functional
.
cross_entropy
(
loss
=
paddle
.
nn
.
functional
.
cross_entropy
(
input
=
out
,
label
=
label
input
=
out
,
label
=
label
...
@@ -240,7 +249,23 @@ def train_mlp(
...
@@ -240,7 +249,23 @@ def train_mlp(
def
test_stage2_stage3
():
def
test_stage2_stage3
():
paddle
.
distributed
.
init_parallel_env
()
paddle
.
distributed
.
init_parallel_env
()
mlp
,
mlp1
,
mlp2
,
mlp3
,
mlp4
,
mlp5
,
mlp6
,
mlp7
,
mlp8
,
mlp9
,
mlp10
=
(
(
mlp
,
mlp1
,
mlp2
,
mlp3
,
mlp4
,
mlp5
,
mlp6
,
mlp7
,
mlp8
,
mlp9
,
mlp10
,
mlp11
,
mlp12
,
)
=
(
MLP
(),
MLP
(),
MLP
(),
MLP
(),
MLP
(),
MLP
(),
MLP
(),
MLP
(),
...
@@ -264,6 +289,8 @@ def test_stage2_stage3():
...
@@ -264,6 +289,8 @@ def test_stage2_stage3():
mlp8
.
set_state_dict
(
state_dict
)
mlp8
.
set_state_dict
(
state_dict
)
mlp9
.
set_state_dict
(
state_dict
)
mlp9
.
set_state_dict
(
state_dict
)
mlp10
.
set_state_dict
(
state_dict
)
mlp10
.
set_state_dict
(
state_dict
)
mlp11
.
set_state_dict
(
state_dict
)
mlp12
.
set_state_dict
(
state_dict
)
# fp32
# fp32
stage2_params
=
train_mlp
(
stage2_params
=
train_mlp
(
...
@@ -336,6 +363,41 @@ def test_stage2_stage3():
...
@@ -336,6 +363,41 @@ def test_stage2_stage3():
stage3_params
[
i
].
numpy
(),
stage3_params_re
[
i
].
numpy
(),
rtol
=
1e-6
stage3_params
[
i
].
numpy
(),
stage3_params_re
[
i
].
numpy
(),
rtol
=
1e-6
)
)
# bfp16
# NOTE: this is a hack to get int format nccl version, like 2134
# if current platform is not linux, version number will be 0
nccl_version_str
=
subprocess
.
check_output
(
r
"ldconfig -v | grep 'libnccl.so' | tail -n1 | sed -r 's/^.*\.so\.//'"
,
stderr
=
subprocess
.
DEVNULL
,
shell
=
True
,
).
decode
(
'utf-8'
)
nccl_version
=
(
int
(
""
.
join
(
nccl_version_str
.
split
(
"."
)))
if
nccl_version_str
else
0
)
if
nccl_version
>=
2100
:
stage2_params
=
train_mlp
(
mlp11
,
sharding_stage
=
2
,
use_pure_fp16
=
True
,
opt_group
=
False
,
use_bfp16
=
True
,
)
stage3_params
=
train_mlp
(
mlp12
,
sharding_stage
=
3
,
use_pure_fp16
=
True
,
opt_group
=
False
,
use_bfp16
=
True
,
)
for
i
in
range
(
len
(
stage2_params
)):
np
.
testing
.
assert_allclose
(
stage2_params
[
i
].
numpy
(),
stage3_params
[
i
].
numpy
(),
rtol
=
1e-4
,
atol
=
1e-3
,
)
# test for share layer parameters and exclude_layer function.
# test for share layer parameters and exclude_layer function.
sm1
,
sm2
,
sm3
,
sm4
=
(
sm1
,
sm2
,
sm3
,
sm4
=
(
SpecialModel
(),
SpecialModel
(),
...
...
python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage3_offload.py
浏览文件 @
1a4a1520
...
@@ -14,6 +14,8 @@
...
@@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
subprocess
import
numpy
as
np
import
numpy
as
np
import
paddle
import
paddle
...
@@ -84,6 +86,7 @@ def optimizer_setting(model, use_pure_fp16, opt_group=False):
...
@@ -84,6 +86,7 @@ def optimizer_setting(model, use_pure_fp16, opt_group=False):
def
train_mlp
(
def
train_mlp
(
model
,
model
,
use_pure_fp16
=
False
,
use_pure_fp16
=
False
,
use_bfp16
=
False
,
accumulate_grad
=
False
,
accumulate_grad
=
False
,
offload
=
False
,
offload
=
False
,
batch_size
=
100
,
batch_size
=
100
,
...
@@ -94,7 +97,10 @@ def train_mlp(
...
@@ -94,7 +97,10 @@ def train_mlp(
if
use_pure_fp16
:
if
use_pure_fp16
:
model
=
paddle
.
amp
.
decorate
(
model
=
paddle
.
amp
.
decorate
(
models
=
model
,
level
=
'O2'
,
save_dtype
=
'float32'
models
=
model
,
level
=
'O2'
,
save_dtype
=
'float32'
,
dtype
=
'bfloat16'
if
use_bfp16
else
'float16'
,
)
)
scaler
=
paddle
.
amp
.
GradScaler
(
init_loss_scaling
=
32768
)
scaler
=
paddle
.
amp
.
GradScaler
(
init_loss_scaling
=
32768
)
scaler
=
GroupShardedScaler
(
scaler
)
scaler
=
GroupShardedScaler
(
scaler
)
...
@@ -123,7 +129,11 @@ def train_mlp(
...
@@ -123,7 +129,11 @@ def train_mlp(
img
,
label
=
data
img
,
label
=
data
label
.
stop_gradient
=
True
label
.
stop_gradient
=
True
img
.
stop_gradient
=
True
img
.
stop_gradient
=
True
with
paddle
.
amp
.
auto_cast
(
use_pure_fp16
,
level
=
'O2'
):
with
paddle
.
amp
.
auto_cast
(
use_pure_fp16
,
level
=
'O2'
,
dtype
=
'bfloat16'
if
use_bfp16
else
'float16'
,
):
out
=
model
(
img
)
out
=
model
(
img
)
loss
=
paddle
.
nn
.
functional
.
cross_entropy
(
loss
=
paddle
.
nn
.
functional
.
cross_entropy
(
input
=
out
,
label
=
label
input
=
out
,
label
=
label
...
@@ -161,7 +171,9 @@ def train_mlp(
...
@@ -161,7 +171,9 @@ def train_mlp(
def
test_stage3_offload
():
def
test_stage3_offload
():
paddle
.
distributed
.
init_parallel_env
()
paddle
.
distributed
.
init_parallel_env
()
mlp
,
mlp1
,
mlp2
,
mlp3
,
mlp4
,
mlp5
,
mlp6
=
(
mlp
,
mlp1
,
mlp2
,
mlp3
,
mlp4
,
mlp5
,
mlp6
,
mlp7
,
mlp8
=
(
MLP
(),
MLP
(),
MLP
(),
MLP
(),
MLP
(),
MLP
(),
MLP
(),
MLP
(),
...
@@ -177,6 +189,8 @@ def test_stage3_offload():
...
@@ -177,6 +189,8 @@ def test_stage3_offload():
mlp4
.
set_state_dict
(
state_dict
)
mlp4
.
set_state_dict
(
state_dict
)
mlp5
.
set_state_dict
(
state_dict
)
mlp5
.
set_state_dict
(
state_dict
)
mlp6
.
set_state_dict
(
state_dict
)
mlp6
.
set_state_dict
(
state_dict
)
mlp7
.
set_state_dict
(
state_dict
)
mlp8
.
set_state_dict
(
state_dict
)
# fp32 offload
# fp32 offload
stage3_params
=
train_mlp
(
mlp1
,
use_pure_fp16
=
False
)
stage3_params
=
train_mlp
(
mlp1
,
use_pure_fp16
=
False
)
...
@@ -200,6 +214,31 @@ def test_stage3_offload():
...
@@ -200,6 +214,31 @@ def test_stage3_offload():
atol
=
1e-2
,
atol
=
1e-2
,
)
)
# bfp16 offload
# NOTE: this is a hack to get int format nccl version, like 2134
# if current platform is not linux, version number will be 0
nccl_version_str
=
subprocess
.
check_output
(
r
"ldconfig -v | grep 'libnccl.so' | tail -n1 | sed -r 's/^.*\.so\.//'"
,
stderr
=
subprocess
.
DEVNULL
,
shell
=
True
,
).
decode
(
'utf-8'
)
nccl_version
=
(
int
(
""
.
join
(
nccl_version_str
.
split
(
"."
)))
if
nccl_version_str
else
0
)
if
nccl_version
>=
2100
:
stage3_params
=
train_mlp
(
mlp7
,
use_pure_fp16
=
True
,
use_bfp16
=
True
)
stage3_params_offload
=
train_mlp
(
mlp8
,
use_pure_fp16
=
True
,
offload
=
True
,
use_bfp16
=
True
)
for
i
in
range
(
len
(
stage3_params
)):
np
.
testing
.
assert_allclose
(
stage3_params
[
i
].
numpy
(),
stage3_params_offload
[
i
].
numpy
(),
rtol
=
1e-2
,
atol
=
1e-2
,
)
# fp32 accumulate grad offload
# fp32 accumulate grad offload
stage3_params
=
train_mlp
(
stage3_params
=
train_mlp
(
mlp5
,
use_pure_fp16
=
False
,
batch_size
=
20
,
accumulate_grad
=
True
mlp5
,
use_pure_fp16
=
False
,
batch_size
=
20
,
accumulate_grad
=
True
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录