Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
1a4a1520
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1a4a1520
编写于
2月 07, 2023
作者:
W
wuhuachaocoding
提交者:
GitHub
2月 07, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support bfp16 for stage3 and offload. (#49931)
上级
05c9c0a5
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
184 addition
and
16 deletion
+184
-16
python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py
...uted/fleet/meta_parallel/sharding/group_sharded_stage3.py
+24
-6
python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
...buted/fleet/meta_parallel/sharding/group_sharded_utils.py
+53
-4
python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage3.py
...nittests/collective/fleet/dygraph_group_sharded_stage3.py
+65
-3
python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage3_offload.py
.../collective/fleet/dygraph_group_sharded_stage3_offload.py
+42
-3
未找到文件。
python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py
浏览文件 @
1a4a1520
...
@@ -45,10 +45,9 @@ def _all_gather(tensor, buffer_size, group):
...
@@ -45,10 +45,9 @@ def _all_gather(tensor, buffer_size, group):
# CUDA alignment 256 bytes
# CUDA alignment 256 bytes
alignment
=
{
alignment
=
{
"gpu"
:
256
,
"cpu"
:
4096
,
"xpu"
:
256
}
"gpu"
:
256
,
}
align
=
{
align
=
{
Type
.
bf16
.
value
:
2
,
Type
.
fp16
.
value
:
2
,
Type
.
fp16
.
value
:
2
,
Type
.
fp32
.
value
:
4
,
Type
.
fp32
.
value
:
4
,
}
}
...
@@ -251,6 +250,11 @@ class GroupShardedStage3(nn.Layer):
...
@@ -251,6 +250,11 @@ class GroupShardedStage3(nn.Layer):
and
param2dtype
[
param
.
name
]
==
Type
.
fp16
.
value
and
param2dtype
[
param
.
name
]
==
Type
.
fp16
.
value
):
):
tmp_var
=
paddle
.
cast
(
tmp_var
,
Type
.
fp16
.
value
)
tmp_var
=
paddle
.
cast
(
tmp_var
,
Type
.
fp16
.
value
)
elif
(
tmp_var
.
dtype
==
Type
.
fp32
.
value
and
param2dtype
[
param
.
name
]
==
Type
.
bf16
.
value
):
tmp_var
=
paddle
.
cast
(
tmp_var
,
Type
.
bf16
.
value
)
tmp_var
.
_share_buffer_to
(
param
)
tmp_var
.
_share_buffer_to
(
param
)
del
tmp_var
del
tmp_var
for
grad_storage
in
self
.
_grad_storages
.
values
():
for
grad_storage
in
self
.
_grad_storages
.
values
():
...
@@ -312,11 +316,14 @@ class GroupShardedStage3(nn.Layer):
...
@@ -312,11 +316,14 @@ class GroupShardedStage3(nn.Layer):
def
_handle_unslice_params
(
self
):
def
_handle_unslice_params
(
self
):
buffer_size
=
dict
()
buffer_size
=
dict
()
buffer_size
[
Type
.
bf16
.
value
]
=
0
buffer_size
[
Type
.
fp32
.
value
]
=
0
buffer_size
[
Type
.
fp32
.
value
]
=
0
buffer_size
[
Type
.
fp16
.
value
]
=
0
buffer_size
[
Type
.
fp16
.
value
]
=
0
for
param
in
self
.
_unslice_params
:
for
param
in
self
.
_unslice_params
:
# Updata optimizer master weights
# Updata optimizer master weights
if
param
.
dtype
==
Type
.
fp16
.
value
and
not
self
.
_offload
:
if
(
param
.
dtype
==
Type
.
fp16
.
value
or
param
.
dtype
==
Type
.
bf16
.
value
)
and
not
self
.
_offload
:
master_tensor
=
paddle
.
cast
(
param
,
Type
.
fp32
.
value
)
master_tensor
=
paddle
.
cast
(
param
,
Type
.
fp32
.
value
)
master_tensor
.
name
=
param
.
name
master_tensor
.
name
=
param
.
name
self
.
_optim
.
_master_weights
[
param
.
name
]
=
master_tensor
self
.
_optim
.
_master_weights
[
param
.
name
]
=
master_tensor
...
@@ -419,10 +426,14 @@ class GroupShardedStage3(nn.Layer):
...
@@ -419,10 +426,14 @@ class GroupShardedStage3(nn.Layer):
assert
isinstance
(
buffer_size
,
int
)
assert
isinstance
(
buffer_size
,
int
)
value
=
(
value
=
(
np
.
zeros
(
buffer_size
,
dtype
=
np
.
float16
)
np
.
zeros
(
buffer_size
,
dtype
=
np
.
float16
)
if
Type
.
fp16
.
value
==
param
.
dtype
if
(
Type
.
fp16
.
value
==
param
.
dtype
or
Type
.
bf16
.
value
==
param
.
dtype
)
else
np
.
zeros
(
buffer_size
,
dtype
=
np
.
float32
)
else
np
.
zeros
(
buffer_size
,
dtype
=
np
.
float32
)
)
)
buffer
=
core
.
eager
.
Tensor
(
value
=
value
,
place
=
core
.
CPUPlace
())
buffer
=
core
.
eager
.
Tensor
(
value
=
value
,
place
=
core
.
CPUPlace
())
if
Type
.
bf16
.
value
==
param
.
dtype
:
buffer
=
buffer
.
cast
(
Type
.
bf16
.
value
)
param_shape
=
param
.
shape
param_shape
=
param
.
shape
origin_state
=
param
.
stop_gradient
origin_state
=
param
.
stop_gradient
...
@@ -462,7 +473,9 @@ class GroupShardedStage3(nn.Layer):
...
@@ -462,7 +473,9 @@ class GroupShardedStage3(nn.Layer):
# Updata optimizer master weights
# Updata optimizer master weights
if
(
if
(
param
.
trainable
param
.
trainable
and
param
.
dtype
==
Type
.
fp16
.
value
and
(
param
.
dtype
==
Type
.
fp16
.
value
or
param
.
dtype
==
Type
.
bf16
.
value
)
and
not
self
.
_offload
and
not
self
.
_offload
):
):
master_tensor
=
paddle
.
cast
(
param
.
fw_storage
,
Type
.
fp32
.
value
)
master_tensor
=
paddle
.
cast
(
param
.
fw_storage
,
Type
.
fp32
.
value
)
...
@@ -1088,6 +1101,11 @@ def _cpu2device(param):
...
@@ -1088,6 +1101,11 @@ def _cpu2device(param):
and
param2dtype
[
param
.
name
]
==
Type
.
fp16
.
value
and
param2dtype
[
param
.
name
]
==
Type
.
fp16
.
value
):
):
tmp_p
=
paddle
.
cast
(
tmp_p
,
Type
.
fp16
.
value
)
tmp_p
=
paddle
.
cast
(
tmp_p
,
Type
.
fp16
.
value
)
elif
(
tmp_p
.
dtype
==
Type
.
fp32
.
value
and
param2dtype
[
param
.
name
]
==
Type
.
bf16
.
value
):
tmp_p
=
paddle
.
cast
(
tmp_p
,
Type
.
bf16
.
value
)
return
tmp_p
return
tmp_p
...
...
python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
浏览文件 @
1a4a1520
...
@@ -54,8 +54,12 @@ class GroupShardedClipGrad:
...
@@ -54,8 +54,12 @@ class GroupShardedClipGrad:
@
paddle
.
autograd
.
no_grad
()
@
paddle
.
autograd
.
no_grad
()
def
_dygraph_clip
(
self
,
params_grads
):
def
_dygraph_clip
(
self
,
params_grads
):
sum_square_fp32
,
sum_square_fp16
=
[],
[]
sum_square_fp32
,
sum_square_fp16
,
sum_square_bfp16
=
[],
[],
[]
unslice_params_fp32
,
unslice_params_fp16
=
[],
[]
unslice_params_fp32
,
unslice_params_fp16
,
unslice_params_bfp16
=
(
[],
[],
[],
)
for
p
,
g
in
params_grads
:
for
p
,
g
in
params_grads
:
p_slice
=
True
# using for slice parameter in sharding stage3
p_slice
=
True
# using for slice parameter in sharding stage3
...
@@ -82,6 +86,11 @@ class GroupShardedClipGrad:
...
@@ -82,6 +86,11 @@ class GroupShardedClipGrad:
sum_square_fp32
.
append
(
sum_square
)
sum_square_fp32
.
append
(
sum_square
)
else
:
else
:
unslice_params_fp32
.
append
(
sum_square
)
unslice_params_fp32
.
append
(
sum_square
)
elif
p
.
dtype
==
paddle
.
bfloat16
:
if
p_slice
:
sum_square_bfp16
.
append
(
sum_square
)
else
:
unslice_params_bfp16
.
append
(
sum_square
)
# global norm of non-distributed FP16 params_and_grads
# global norm of non-distributed FP16 params_and_grads
if
len
(
sum_square_fp16
)
==
0
:
if
len
(
sum_square_fp16
)
==
0
:
...
@@ -93,6 +102,16 @@ class GroupShardedClipGrad:
...
@@ -93,6 +102,16 @@ class GroupShardedClipGrad:
global_norm_fp16
,
dtype
=
paddle
.
float32
global_norm_fp16
,
dtype
=
paddle
.
float32
)
)
# global norm of non-distributed BFP16 params_and_grads
if
len
(
sum_square_bfp16
)
==
0
:
global_norm_bfp16
=
paddle
.
to_tensor
([
0.0
],
dtype
=
paddle
.
float32
)
else
:
global_norm_bfp16
=
paddle
.
concat
(
sum_square_bfp16
)
global_norm_bfp16
=
paddle
.
sum
(
global_norm_bfp16
)
global_norm_bfp16
=
paddle
.
cast
(
global_norm_bfp16
,
dtype
=
paddle
.
float32
)
# global norm of non-distributed FP16 params_and_grads for unslice parameters
# global norm of non-distributed FP16 params_and_grads for unslice parameters
if
len
(
unslice_params_fp16
)
==
0
:
if
len
(
unslice_params_fp16
)
==
0
:
global_unslice_fp16
=
paddle
.
to_tensor
([
0.0
],
dtype
=
paddle
.
float32
)
global_unslice_fp16
=
paddle
.
to_tensor
([
0.0
],
dtype
=
paddle
.
float32
)
...
@@ -103,6 +122,16 @@ class GroupShardedClipGrad:
...
@@ -103,6 +122,16 @@ class GroupShardedClipGrad:
global_unslice_fp16
,
dtype
=
paddle
.
float32
global_unslice_fp16
,
dtype
=
paddle
.
float32
)
)
# global norm of non-distributed BFP16 params_and_grads for unslice parameters
if
len
(
unslice_params_bfp16
)
==
0
:
global_unslice_bfp16
=
paddle
.
to_tensor
([
0.0
],
dtype
=
paddle
.
float32
)
else
:
global_unslice_bfp16
=
paddle
.
concat
(
unslice_params_bfp16
)
global_unslice_bfp16
=
paddle
.
sum
(
global_unslice_bfp16
)
global_unslice_bfp16
=
paddle
.
cast
(
global_unslice_bfp16
,
dtype
=
paddle
.
float32
)
# global norm of non-distributed FP32 params_and_grads
# global norm of non-distributed FP32 params_and_grads
global_norm_fp32
=
(
global_norm_fp32
=
(
paddle
.
concat
(
sum_square_fp32
)
paddle
.
concat
(
sum_square_fp32
)
...
@@ -118,9 +147,13 @@ class GroupShardedClipGrad:
...
@@ -118,9 +147,13 @@ class GroupShardedClipGrad:
else
paddle
.
to_tensor
([
0.0
],
dtype
=
paddle
.
float32
)
else
paddle
.
to_tensor
([
0.0
],
dtype
=
paddle
.
float32
)
)
)
global_unslice_fp32
=
paddle
.
sum
(
global_unslice_fp32
)
global_unslice_fp32
=
paddle
.
sum
(
global_unslice_fp32
)
global_unslice_var
=
global_unslice_fp16
+
global_unslice_fp32
global_unslice_var
=
(
global_unslice_fp16
+
global_unslice_fp32
+
global_unslice_bfp16
)
global_norm_var
=
global_norm_fp16
+
global_norm_fp32
global_norm_var
=
(
global_norm_fp16
+
global_norm_fp32
+
global_norm_bfp16
)
# add all reduce to get global norm of distributed params_and_grads
# add all reduce to get global norm of distributed params_and_grads
dev_id
=
int
(
self
.
_device
.
split
(
":"
)[
1
])
dev_id
=
int
(
self
.
_device
.
split
(
":"
)[
1
])
...
@@ -181,6 +214,7 @@ def GroupShardedScaler(scaler):
...
@@ -181,6 +214,7 @@ def GroupShardedScaler(scaler):
if
not
self
.
_enable
:
if
not
self
.
_enable
:
return
return
param_grads
=
[]
param_grads
=
[]
param_grads_bfp16
=
[]
param_grads_fp16
=
[]
param_grads_fp16
=
[]
param_grads_fp32
=
[]
param_grads_fp32
=
[]
if
hasattr
(
optimizer
,
"update_slice"
):
if
hasattr
(
optimizer
,
"update_slice"
):
...
@@ -200,6 +234,8 @@ def GroupShardedScaler(scaler):
...
@@ -200,6 +234,8 @@ def GroupShardedScaler(scaler):
paddle
.
float16
,
paddle
.
float16
,
]:
]:
param_grads_fp16
.
append
(
param
.
grad
)
param_grads_fp16
.
append
(
param
.
grad
)
elif
param
.
grad
.
dtype
in
[
paddle
.
bfloat16
]:
param_grads_bfp16
.
append
(
param
.
grad
)
else
:
else
:
param_grads_fp32
.
append
(
param
.
grad
)
param_grads_fp32
.
append
(
param
.
grad
)
else
:
else
:
...
@@ -211,10 +247,13 @@ def GroupShardedScaler(scaler):
...
@@ -211,10 +247,13 @@ def GroupShardedScaler(scaler):
paddle
.
float16
,
paddle
.
float16
,
]:
]:
param_grads_fp16
.
append
(
param
.
grad
)
param_grads_fp16
.
append
(
param
.
grad
)
elif
param
.
grad
.
dtype
in
[
paddle
.
bfloat16
]:
param_grads_bfp16
.
append
(
param
.
grad
)
else
:
else
:
param_grads_fp32
.
append
(
param
.
grad
)
param_grads_fp32
.
append
(
param
.
grad
)
temp_found_inf_fp16
=
to_variable
(
np
.
array
([
0
]).
astype
(
np
.
bool_
))
temp_found_inf_fp16
=
to_variable
(
np
.
array
([
0
]).
astype
(
np
.
bool_
))
temp_found_inf_bfp16
=
to_variable
(
np
.
array
([
0
]).
astype
(
np
.
bool_
))
temp_found_inf_fp32
=
to_variable
(
np
.
array
([
0
]).
astype
(
np
.
bool_
))
temp_found_inf_fp32
=
to_variable
(
np
.
array
([
0
]).
astype
(
np
.
bool_
))
device
=
paddle
.
get_device
().
split
(
":"
)[
0
]
device
=
paddle
.
get_device
().
split
(
":"
)[
0
]
...
@@ -224,6 +263,16 @@ def GroupShardedScaler(scaler):
...
@@ -224,6 +263,16 @@ def GroupShardedScaler(scaler):
)
)
with
device_guard
(
dev_id
,
device
):
with
device_guard
(
dev_id
,
device
):
if
len
(
param_grads_bfp16
):
_legacy_C_ops
.
check_finite_and_unscale
(
param_grads_bfp16
,
self
.
_scale
,
param_grads_bfp16
,
temp_found_inf_bfp16
,
)
self
.
_found_inf
=
_C_ops
.
bitwise_or
(
self
.
_found_inf
,
temp_found_inf_bfp16
)
if
len
(
param_grads_fp16
):
if
len
(
param_grads_fp16
):
_legacy_C_ops
.
check_finite_and_unscale
(
_legacy_C_ops
.
check_finite_and_unscale
(
param_grads_fp16
,
param_grads_fp16
,
...
...
python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage3.py
浏览文件 @
1a4a1520
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
import
os
import
os
import
shutil
import
shutil
import
subprocess
import
tempfile
import
tempfile
import
numpy
as
np
import
numpy
as
np
...
@@ -135,6 +136,7 @@ def train_mlp(
...
@@ -135,6 +136,7 @@ def train_mlp(
model
,
model
,
sharding_stage
,
sharding_stage
,
use_pure_fp16
=
False
,
use_pure_fp16
=
False
,
use_bfp16
=
False
,
accumulate_grad
=
False
,
accumulate_grad
=
False
,
batch_size
=
100
,
batch_size
=
100
,
opt_group
=
False
,
opt_group
=
False
,
...
@@ -154,7 +156,10 @@ def train_mlp(
...
@@ -154,7 +156,10 @@ def train_mlp(
if
use_pure_fp16
:
if
use_pure_fp16
:
model
=
paddle
.
amp
.
decorate
(
model
=
paddle
.
amp
.
decorate
(
models
=
model
,
level
=
'O2'
,
save_dtype
=
'float32'
models
=
model
,
level
=
'O2'
,
save_dtype
=
'float32'
,
dtype
=
'bfloat16'
if
use_bfp16
else
'float16'
,
)
)
scaler
=
paddle
.
amp
.
GradScaler
(
init_loss_scaling
=
32768
)
scaler
=
paddle
.
amp
.
GradScaler
(
init_loss_scaling
=
32768
)
scaler
=
GroupShardedScaler
(
scaler
)
scaler
=
GroupShardedScaler
(
scaler
)
...
@@ -201,7 +206,11 @@ def train_mlp(
...
@@ -201,7 +206,11 @@ def train_mlp(
img
,
label
=
data
img
,
label
=
data
label
.
stop_gradient
=
True
label
.
stop_gradient
=
True
img
.
stop_gradient
=
True
img
.
stop_gradient
=
True
with
paddle
.
amp
.
auto_cast
(
use_pure_fp16
,
level
=
'O2'
):
with
paddle
.
amp
.
auto_cast
(
use_pure_fp16
,
level
=
'O2'
,
dtype
=
'bfloat16'
if
use_bfp16
else
'float16'
,
):
out
=
model
(
img
)
out
=
model
(
img
)
loss
=
paddle
.
nn
.
functional
.
cross_entropy
(
loss
=
paddle
.
nn
.
functional
.
cross_entropy
(
input
=
out
,
label
=
label
input
=
out
,
label
=
label
...
@@ -240,7 +249,23 @@ def train_mlp(
...
@@ -240,7 +249,23 @@ def train_mlp(
def
test_stage2_stage3
():
def
test_stage2_stage3
():
paddle
.
distributed
.
init_parallel_env
()
paddle
.
distributed
.
init_parallel_env
()
mlp
,
mlp1
,
mlp2
,
mlp3
,
mlp4
,
mlp5
,
mlp6
,
mlp7
,
mlp8
,
mlp9
,
mlp10
=
(
(
mlp
,
mlp1
,
mlp2
,
mlp3
,
mlp4
,
mlp5
,
mlp6
,
mlp7
,
mlp8
,
mlp9
,
mlp10
,
mlp11
,
mlp12
,
)
=
(
MLP
(),
MLP
(),
MLP
(),
MLP
(),
MLP
(),
MLP
(),
MLP
(),
MLP
(),
...
@@ -264,6 +289,8 @@ def test_stage2_stage3():
...
@@ -264,6 +289,8 @@ def test_stage2_stage3():
mlp8
.
set_state_dict
(
state_dict
)
mlp8
.
set_state_dict
(
state_dict
)
mlp9
.
set_state_dict
(
state_dict
)
mlp9
.
set_state_dict
(
state_dict
)
mlp10
.
set_state_dict
(
state_dict
)
mlp10
.
set_state_dict
(
state_dict
)
mlp11
.
set_state_dict
(
state_dict
)
mlp12
.
set_state_dict
(
state_dict
)
# fp32
# fp32
stage2_params
=
train_mlp
(
stage2_params
=
train_mlp
(
...
@@ -336,6 +363,41 @@ def test_stage2_stage3():
...
@@ -336,6 +363,41 @@ def test_stage2_stage3():
stage3_params
[
i
].
numpy
(),
stage3_params_re
[
i
].
numpy
(),
rtol
=
1e-6
stage3_params
[
i
].
numpy
(),
stage3_params_re
[
i
].
numpy
(),
rtol
=
1e-6
)
)
# bfp16
# NOTE: this is a hack to get int format nccl version, like 2134
# if current platform is not linux, version number will be 0
nccl_version_str
=
subprocess
.
check_output
(
r
"ldconfig -v | grep 'libnccl.so' | tail -n1 | sed -r 's/^.*\.so\.//'"
,
stderr
=
subprocess
.
DEVNULL
,
shell
=
True
,
).
decode
(
'utf-8'
)
nccl_version
=
(
int
(
""
.
join
(
nccl_version_str
.
split
(
"."
)))
if
nccl_version_str
else
0
)
if
nccl_version
>=
2100
:
stage2_params
=
train_mlp
(
mlp11
,
sharding_stage
=
2
,
use_pure_fp16
=
True
,
opt_group
=
False
,
use_bfp16
=
True
,
)
stage3_params
=
train_mlp
(
mlp12
,
sharding_stage
=
3
,
use_pure_fp16
=
True
,
opt_group
=
False
,
use_bfp16
=
True
,
)
for
i
in
range
(
len
(
stage2_params
)):
np
.
testing
.
assert_allclose
(
stage2_params
[
i
].
numpy
(),
stage3_params
[
i
].
numpy
(),
rtol
=
1e-4
,
atol
=
1e-3
,
)
# test for share layer parameters and exclude_layer function.
# test for share layer parameters and exclude_layer function.
sm1
,
sm2
,
sm3
,
sm4
=
(
sm1
,
sm2
,
sm3
,
sm4
=
(
SpecialModel
(),
SpecialModel
(),
...
...
python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage3_offload.py
浏览文件 @
1a4a1520
...
@@ -14,6 +14,8 @@
...
@@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
subprocess
import
numpy
as
np
import
numpy
as
np
import
paddle
import
paddle
...
@@ -84,6 +86,7 @@ def optimizer_setting(model, use_pure_fp16, opt_group=False):
...
@@ -84,6 +86,7 @@ def optimizer_setting(model, use_pure_fp16, opt_group=False):
def
train_mlp
(
def
train_mlp
(
model
,
model
,
use_pure_fp16
=
False
,
use_pure_fp16
=
False
,
use_bfp16
=
False
,
accumulate_grad
=
False
,
accumulate_grad
=
False
,
offload
=
False
,
offload
=
False
,
batch_size
=
100
,
batch_size
=
100
,
...
@@ -94,7 +97,10 @@ def train_mlp(
...
@@ -94,7 +97,10 @@ def train_mlp(
if
use_pure_fp16
:
if
use_pure_fp16
:
model
=
paddle
.
amp
.
decorate
(
model
=
paddle
.
amp
.
decorate
(
models
=
model
,
level
=
'O2'
,
save_dtype
=
'float32'
models
=
model
,
level
=
'O2'
,
save_dtype
=
'float32'
,
dtype
=
'bfloat16'
if
use_bfp16
else
'float16'
,
)
)
scaler
=
paddle
.
amp
.
GradScaler
(
init_loss_scaling
=
32768
)
scaler
=
paddle
.
amp
.
GradScaler
(
init_loss_scaling
=
32768
)
scaler
=
GroupShardedScaler
(
scaler
)
scaler
=
GroupShardedScaler
(
scaler
)
...
@@ -123,7 +129,11 @@ def train_mlp(
...
@@ -123,7 +129,11 @@ def train_mlp(
img
,
label
=
data
img
,
label
=
data
label
.
stop_gradient
=
True
label
.
stop_gradient
=
True
img
.
stop_gradient
=
True
img
.
stop_gradient
=
True
with
paddle
.
amp
.
auto_cast
(
use_pure_fp16
,
level
=
'O2'
):
with
paddle
.
amp
.
auto_cast
(
use_pure_fp16
,
level
=
'O2'
,
dtype
=
'bfloat16'
if
use_bfp16
else
'float16'
,
):
out
=
model
(
img
)
out
=
model
(
img
)
loss
=
paddle
.
nn
.
functional
.
cross_entropy
(
loss
=
paddle
.
nn
.
functional
.
cross_entropy
(
input
=
out
,
label
=
label
input
=
out
,
label
=
label
...
@@ -161,7 +171,9 @@ def train_mlp(
...
@@ -161,7 +171,9 @@ def train_mlp(
def
test_stage3_offload
():
def
test_stage3_offload
():
paddle
.
distributed
.
init_parallel_env
()
paddle
.
distributed
.
init_parallel_env
()
mlp
,
mlp1
,
mlp2
,
mlp3
,
mlp4
,
mlp5
,
mlp6
=
(
mlp
,
mlp1
,
mlp2
,
mlp3
,
mlp4
,
mlp5
,
mlp6
,
mlp7
,
mlp8
=
(
MLP
(),
MLP
(),
MLP
(),
MLP
(),
MLP
(),
MLP
(),
MLP
(),
MLP
(),
...
@@ -177,6 +189,8 @@ def test_stage3_offload():
...
@@ -177,6 +189,8 @@ def test_stage3_offload():
mlp4
.
set_state_dict
(
state_dict
)
mlp4
.
set_state_dict
(
state_dict
)
mlp5
.
set_state_dict
(
state_dict
)
mlp5
.
set_state_dict
(
state_dict
)
mlp6
.
set_state_dict
(
state_dict
)
mlp6
.
set_state_dict
(
state_dict
)
mlp7
.
set_state_dict
(
state_dict
)
mlp8
.
set_state_dict
(
state_dict
)
# fp32 offload
# fp32 offload
stage3_params
=
train_mlp
(
mlp1
,
use_pure_fp16
=
False
)
stage3_params
=
train_mlp
(
mlp1
,
use_pure_fp16
=
False
)
...
@@ -200,6 +214,31 @@ def test_stage3_offload():
...
@@ -200,6 +214,31 @@ def test_stage3_offload():
atol
=
1e-2
,
atol
=
1e-2
,
)
)
# bfp16 offload
# NOTE: this is a hack to get int format nccl version, like 2134
# if current platform is not linux, version number will be 0
nccl_version_str
=
subprocess
.
check_output
(
r
"ldconfig -v | grep 'libnccl.so' | tail -n1 | sed -r 's/^.*\.so\.//'"
,
stderr
=
subprocess
.
DEVNULL
,
shell
=
True
,
).
decode
(
'utf-8'
)
nccl_version
=
(
int
(
""
.
join
(
nccl_version_str
.
split
(
"."
)))
if
nccl_version_str
else
0
)
if
nccl_version
>=
2100
:
stage3_params
=
train_mlp
(
mlp7
,
use_pure_fp16
=
True
,
use_bfp16
=
True
)
stage3_params_offload
=
train_mlp
(
mlp8
,
use_pure_fp16
=
True
,
offload
=
True
,
use_bfp16
=
True
)
for
i
in
range
(
len
(
stage3_params
)):
np
.
testing
.
assert_allclose
(
stage3_params
[
i
].
numpy
(),
stage3_params_offload
[
i
].
numpy
(),
rtol
=
1e-2
,
atol
=
1e-2
,
)
# fp32 accumulate grad offload
# fp32 accumulate grad offload
stage3_params
=
train_mlp
(
stage3_params
=
train_mlp
(
mlp5
,
use_pure_fp16
=
False
,
batch_size
=
20
,
accumulate_grad
=
True
mlp5
,
use_pure_fp16
=
False
,
batch_size
=
20
,
accumulate_grad
=
True
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录