Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
23d559dd
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
23d559dd
编写于
2月 08, 2022
作者:
B
Baibaifan
提交者:
GitHub
2月 08, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimize sharding stage3 (#39334)
上级
41eb2595
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
178 addition
and
72 deletion
+178
-72
python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
...stributed/fleet/meta_parallel/sharding/sharding_stage3.py
+133
-64
python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
...istributed/fleet/meta_parallel/sharding/sharding_utils.py
+26
-4
python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py
...n/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py
+5
-1
python/paddle/fluid/tests/unittests/dygraph_sharding_stage3_offload.py
.../fluid/tests/unittests/dygraph_sharding_stage3_offload.py
+14
-3
未找到文件。
python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
浏览文件 @
23d559dd
...
...
@@ -35,6 +35,7 @@ from paddle.distributed.collective import _get_global_group
from
.sharding_utils
import
Type
,
ShardingClipGrad
,
device_guard
from
..pp_utils.utils
import
_all_gather
from
...utils.internal_storage
import
GradStorage
# CUDA alignment 256 bytes
alignment
=
{
"gpu"
:
256
,
}
...
...
@@ -69,6 +70,7 @@ class ShardingStage3(nn.Layer):
group
=
None
,
sync_buffers
=
False
,
device
=
"gpu"
,
segment_size
=
2
**
15
,
pertrain_sync_models
=
True
,
accumulate_grads
=
False
,
offload
=
False
,
...
...
@@ -83,6 +85,8 @@ class ShardingStage3(nn.Layer):
self
.
_accumulate_grads
=
accumulate_grads
self
.
_offload
=
offload
self
.
_sync_comm
=
sync_comm
# segmentation size
self
.
_segment_size
=
segment_size
if
not
offload
else
0
global
DEV
DEV
=
"cpu"
if
paddle
.
get_device
()
==
"cpu"
else
paddle
.
get_device
(
...
...
@@ -107,7 +111,10 @@ class ShardingStage3(nn.Layer):
self
.
_param2buffer_size
=
dict
()
# {param.name: size}
self
.
_param2buffer
=
dict
(
)
# {param.name: [(start0, end0),(start1, end1), ...]}
self
.
_trainable_params
=
dict
()
# {layer.name: [trainable_params]}
self
.
_trainable_params
=
dict
()
# {id(layer): [trainable_params]}
self
.
_unslice_params
=
set
()
# param's numel <= segment_size
self
.
_unslice_params2align
=
dict
()
# {param.name: param's align}
self
.
_grad_storages
=
dict
()
# {param.dtype: GradStorage}
assert
not
isinstance
(
optimizer
,
list
),
"Multiple optimizers are not supported now."
...
...
@@ -131,10 +138,13 @@ class ShardingStage3(nn.Layer):
self
.
_segment_rank_params
(
self
.
_layer
)
# Add unslice params to master_weight in fp16
self
.
_handle_unslice_params
()
# In the first step, record the execution order of the layer
self
.
_order_tracer
=
OrderedDict
()
self
.
_order_tracer
[
"order"
]
=
0
self
.
_order_tracer
[
"layer"
]
=
[]
self
.
_order_tracer
[
"layer"
]
=
list
()
# Register task flow
self
.
_task_flow
=
TaskFlow
()
...
...
@@ -168,8 +178,10 @@ class ShardingStage3(nn.Layer):
def
_clear_gradients
(
self
):
assert
len
(
self
.
_trainable_params
.
keys
())
>
0
current_layer_params
=
self
.
_layer
.
parameters
(
include_sublayers
=
True
)
# 1.Handle param's slice
trainable_params
=
list
(
filter
(
lambda
x
:
x
.
trainable
,
current_layer_params
))
filter
(
lambda
p
:
p
.
trainable
and
p
not
in
self
.
_unslice_params
,
current_layer_params
))
for
param
in
trainable_params
:
assert
hasattr
(
param
,
"fw_storage"
...
...
@@ -178,6 +190,9 @@ class ShardingStage3(nn.Layer):
param
.
fw_storage
.
clear_gradient
(
False
)
param
.
fw_storage
.
_gradient_set_empty
(
False
)
param
.
bw_storage
.
_clear
()
# 2.Handle unslice param
for
grad_storage
in
self
.
_grad_storages
.
values
():
grad_storage
.
buffer
.
zero_
()
# Update param memery slice
def
_update_params_slice
(
self
):
...
...
@@ -185,20 +200,25 @@ class ShardingStage3(nn.Layer):
if
not
isinstance
(
self
.
_optim
.
_param_groups
[
0
],
dict
):
slice_params
=
[
param
.
fw_storage
for
param
in
update_list
]
self
.
_optim
.
_parameter_list
=
slice_params
self
.
_optim
.
_param_groups
=
slice_params
self
.
_optim
.
_parameter_list
=
slice_params
+
list
(
self
.
_unslice_params
)
self
.
_optim
.
_param_groups
=
slice_params
+
list
(
self
.
_unslice_params
)
else
:
params_name_list
=
list
(
map
(
lambda
p
:
p
.
name
,
update_list
))
fw_storage_name_list
=
list
(
map
(
lambda
p
:
p
.
fw_storage
.
name
,
update_list
))
for
param_group
in
self
.
_optim
.
_param_groups
:
slice_
p
=
[]
p_grou
p
=
[]
for
p
in
param_group
[
'params'
]:
if
p
.
name
in
params_name_list
:
assert
hasattr
(
p
,
"fw_storage"
),
"Find {} don't have fw_storage attribute."
.
format
(
p
.
name
)
slice_p
.
append
(
p
.
fw_storage
)
param_group
[
'params'
]
=
slice_p
p_group
.
append
(
p
.
fw_storage
)
elif
p
.
name
in
fw_storage_name_list
:
p_group
.
append
(
update_list
[
fw_storage_name_list
.
index
(
p
.
name
)].
fw_storage
)
elif
p
in
self
.
_unslice_params
:
p_group
.
append
(
p
)
param_group
[
'params'
]
=
p_group
def
forward
(
self
,
*
inputs
,
**
kwargs
):
"""
...
...
@@ -213,6 +233,32 @@ class ShardingStage3(nn.Layer):
return
fw
def
_handle_unslice_params
(
self
):
buffer_size
=
dict
()
buffer_size
[
Type
.
fp32
.
value
]
=
0
buffer_size
[
Type
.
fp16
.
value
]
=
0
for
param
in
self
.
_unslice_params
:
# Updata optimizer master weights
if
param
.
dtype
==
Type
.
fp16
.
value
and
not
self
.
_offload
:
self
.
_optim
.
_master_weights
[
param
.
name
]
=
paddle
.
cast
(
param
,
Type
.
fp32
.
value
)
param2dtype
[
param
.
name
]
=
param
.
dtype
p_align
=
self
.
_param2align
(
param
)
self
.
_unslice_params2align
[
param
.
name
]
=
p_align
buffer_size
[
param
.
dtype
]
+=
param
.
_numel
()
+
p_align
# Create unslice_params'grad
for
param
in
sorted
(
list
(
self
.
_unslice_params
),
key
=
lambda
p
:
p
.
name
):
if
param
.
dtype
not
in
self
.
_grad_storages
.
keys
():
self
.
_grad_storages
[
param
.
dtype
]
=
GradStorage
(
buffer_size
[
param
.
dtype
],
dtype
=
param
.
dtype
,
device
=
self
.
_default_device
,
destination
=
self
.
_rank
,
parm2align
=
self
.
_unslice_params2align
)
self
.
_grad_storages
[
param
.
dtype
].
add_grad
(
param
,
self
.
_unslice_params2align
[
param
.
name
])
def
_segment_rank_params
(
self
,
layer
,
name
=
"last_layer"
):
"""
Flatten parameters according to layer.
...
...
@@ -233,24 +279,22 @@ class ShardingStage3(nn.Layer):
def
_add_manage_info
(
trainable_param
):
return
_PartitionParam
(
trainable_param
)
trainable_params
=
list
(
filter
(
lambda
x
:
x
.
trainable
,
current_layer_params
))
current_params
=
list
()
for
p
in
current_layer_params
:
if
p
.
trainable
and
p
.
_numel
()
>
self
.
_segment_size
:
current_params
.
append
(
_add_manage_info
(
p
))
elif
p
.
trainable
:
self
.
_unslice_params
.
add
(
_UnsliceParam
(
p
))
assert
id
(
layer
)
not
in
self
.
_trainable_params
.
keys
()
self
.
_trainable_params
[
id
(
layer
)]
=
list
(
map
(
_add_manage_info
,
trainable_params
))
self
.
_trainable_params
[
id
(
layer
)]
=
current_params
for
param
in
self
.
_trainable_params
[
id
(
layer
)]:
if
param
.
name
in
self
.
_param2buffer
.
keys
():
continue
self
.
_param2buffer
[
param
.
name
]
=
[]
# 1.Params alignment
offset
=
0
# CUDA alignment 256 bytes
size
=
param
.
_numel
()
*
align
[
param
.
dtype
]
remaining
=
size
%
alignment
[
self
.
_default_device
]
ali
=
0
if
remaining
==
0
else
alignment
[
self
.
_default_device
]
-
remaining
align_
=
ali
//
align
[
param
.
dtype
]
align_
=
self
.
_param2align
(
param
)
offset
=
align_
+
param
.
_numel
()
buffer_size
=
offset
if
offset
%
self
.
_group
.
nranks
==
0
else
offset
+
self
.
_group
.
nranks
-
(
...
...
@@ -379,7 +423,9 @@ class ShardingStage3(nn.Layer):
assert
len
(
self
.
_trainable_params
.
keys
())
>
0
current_layer_params
=
self
.
_layer
.
parameters
(
include_sublayers
=
True
)
trainable_params
=
list
(
filter
(
lambda
x
:
x
.
trainable
,
current_layer_params
))
filter
(
lambda
p
:
p
.
trainable
and
p
not
in
self
.
_unslice_params
,
current_layer_params
))
# 1.Handle param's slice
for
param
in
trainable_params
:
assert
hasattr
(
param
,
...
...
@@ -396,6 +442,19 @@ class ShardingStage3(nn.Layer):
assert
param
.
fw_storage
.
grad
is
None
param
.
fw_storage
.
_copy_gradient_from
(
param
.
bw_storage
)
update_list
.
append
(
param
)
# 2.Handle unslice param
for
grad_storage
in
self
.
_grad_storages
.
values
():
grad_storage
.
buffer
.
scale_
(
scale
=
self
.
_world_size_scaling
)
dist
.
all_reduce
(
tensor
=
grad_storage
.
buffer
,
group
=
self
.
_group
,
use_calc_stream
=
True
)
dist
.
wait
(
tensor
=
grad_storage
.
buffer
,
group
=
self
.
_group
,
use_calc_stream
=
True
)
return
update_list
def
get_all_parameters
(
self
,
convert2cpu
=
False
):
...
...
@@ -405,7 +464,8 @@ class ShardingStage3(nn.Layer):
assert
len
(
self
.
_trainable_params
.
keys
())
>
0
current_layer_params
=
self
.
_layer
.
parameters
(
include_sublayers
=
True
)
trainable_params
=
list
(
filter
(
lambda
x
:
x
.
trainable
,
current_layer_params
))
filter
(
lambda
p
:
p
.
trainable
and
p
not
in
self
.
_unslice_params
,
current_layer_params
))
t_flow
=
_allgather_buffer
(
trainable_params
,
self
.
_group
,
...
...
@@ -415,7 +475,7 @@ class ShardingStage3(nn.Layer):
offload
=
self
.
_offload
,
convert2cpu
=
convert2cpu
)
if
convert2cpu
:
for
param
in
current_layer
_params
:
for
param
in
trainable
_params
:
t_flow
.
full_param
[
param
.
name
].
_share_buffer_to
(
param
)
self
.
_optim
.
_parameter_list
=
self
.
_ori_parameter_list
...
...
@@ -424,7 +484,8 @@ class ShardingStage3(nn.Layer):
def
_register_backward_hooks
(
self
):
current_layer_params
=
self
.
_layer
.
parameters
(
include_sublayers
=
True
)
trainable_params
=
list
(
filter
(
lambda
x
:
x
.
trainable
,
current_layer_params
))
filter
(
lambda
p
:
p
.
trainable
and
p
not
in
self
.
_unslice_params
,
current_layer_params
))
for
param
in
trainable_params
:
allreduce_function
=
self
.
_get_allreduce_fn
(
param
)
...
...
@@ -435,42 +496,36 @@ class ShardingStage3(nn.Layer):
def
reduce
(
*
_
):
if
param
.
name
in
self
.
_task_flow
.
full_grad
.
keys
():
full_grad
=
self
.
_task_flow
.
full_grad
[
param
.
name
]
with
paddle
.
amp
.
auto_cast
(
enable
=
False
):
if
not
self
.
_accumulate_grads
:
full_grad
.
scale_
(
scale
=
self
.
_world_size_scaling
)
# Only support sync allreduce current rank's layer now
dist
.
all_reduce
(
tensor
=
full_grad
,
group
=
self
.
_group
,
use_calc_stream
=
True
)
dist
.
wait
(
tensor
=
full_grad
,
group
=
self
.
_group
,
use_calc_stream
=
True
)
if
not
self
.
_accumulate_grads
:
full_grad
.
scale_
(
scale
=
self
.
_world_size_scaling
)
# Only support sync allreduce current rank's layer now
dist
.
all_reduce
(
tensor
=
full_grad
,
group
=
self
.
_group
,
use_calc_stream
=
True
)
dist
.
wait
(
tensor
=
full_grad
,
group
=
self
.
_group
,
use_calc_stream
=
True
)
start
,
end
=
self
.
_param2buffer
[
param
.
name
][
self
.
_rank
]
if
not
self
.
_accumulate_grads
or
param
.
bw_storage
is
None
or
not
param
.
bw_storage
.
value
(
).
get_tensor
().
_is_initialized
():
param
.
bw_storage
=
core
.
VarBase
(
full_grad
.
_slice
(
start
,
end
)).
detach
().
clone
()
if
self
.
_offload
:
param
.
bw_storage
=
_device2cpu
(
param
.
bw_storage
,
True
)
start
,
end
=
self
.
_param2buffer
[
param
.
name
][
self
.
_rank
]
if
not
self
.
_accumulate_grads
or
param
.
bw_storage
is
None
or
not
param
.
bw_storage
.
value
(
).
get_tensor
().
_is_initialized
():
param
.
bw_storage
=
core
.
VarBase
(
full_grad
.
_slice
(
start
,
end
)).
detach
().
clone
()
if
self
.
_offload
:
param
.
bw_storage
=
_device2cpu
(
param
.
bw_storage
,
True
)
else
:
if
self
.
_offload
:
cpu_grad
=
_device2cpu
(
core
.
VarBase
(
full_grad
.
_slice
(
start
,
end
))
.
detach
().
clone
(),
True
)
param
.
bw_storage
=
paddle
.
add
(
param
.
bw_storage
,
cpu_grad
)
else
:
if
self
.
_offload
:
cpu_grad
=
_device2cpu
(
core
.
VarBase
(
full_grad
.
_slice
(
start
,
end
))
.
detach
().
clone
(),
True
)
param
.
bw_storage
=
paddle
.
add
(
param
.
bw_storage
,
cpu_grad
)
else
:
# param.bw_storage.add_(
# core.VarBase(full_grad._slice(start, end))
# .detach().clone())
param
.
bw_storage
=
paddle
.
add
(
param
.
bw_storage
,
core
.
VarBase
(
full_grad
.
_slice
(
start
,
end
)).
detach
().
clone
())
# param.bw_storage.add_(
# core.VarBase(full_grad._slice(start, end))
# .detach().clone())
param
.
bw_storage
=
paddle
.
add
(
param
.
bw_storage
,
core
.
VarBase
(
full_grad
.
_slice
(
start
,
end
)).
detach
(
).
clone
())
param
.
clear_gradient
(
False
)
param
.
_gradient_set_empty
(
False
)
tmp_var
=
self
.
_task_flow
.
full_grad
.
pop
(
param
.
name
)
...
...
@@ -493,6 +548,15 @@ class ShardingStage3(nn.Layer):
return
reduce
def
_param2align
(
self
,
param
):
# CUDA alignment 256 bytes
size
=
param
.
_numel
()
*
align
[
param
.
dtype
]
remaining
=
size
%
alignment
[
self
.
_default_device
]
ali
=
0
if
remaining
==
0
else
alignment
[
self
.
_default_device
]
-
remaining
align_
=
ali
//
align
[
param
.
dtype
]
return
align_
def
_redefine_opt_step
(
self
):
params_slice_func
=
self
.
_update_params_slice
opt_step
=
self
.
_optim
.
step
...
...
@@ -679,14 +743,13 @@ def _wait_layer(trainable_params,
group
,
use_calc_stream
,
offload
=
False
):
paddle
.
device
.
cuda
.
synchronize
()
for
param
in
trainable_params
:
if
param
.
status
==
"all"
:
param
.
use_count
+=
1
continue
if
param
.
name
in
task_flow
.
full_param
.
keys
():
full_param
=
task_flow
.
full_param
[
param
.
name
]
with
paddle
.
amp
.
auto_cast
(
enable
=
False
):
paddle
.
device
.
cuda
.
synchronize
()
core
.
VarBase
(
full_param
.
_slice
(
0
,
param
.
_numel
())).
_share_buffer_to
(
param
)
param
.
fw_storage
.
_clear
()
...
...
@@ -725,7 +788,7 @@ def _allgather_buffer(trainable_params,
full_param
=
_all_gather
(
param
.
fw_storage
,
group
,
use_calc_stream
=
use_calc_stream
)
# Allgather current layer in the 1st step
# Allgather current layer in the 1st step
synchronously
if
sync_wait
:
with
paddle
.
amp
.
auto_cast
(
enable
=
False
):
dist
.
wait
(
...
...
@@ -774,6 +837,12 @@ def _PartitionParam(param):
return
param
def
_UnsliceParam
(
param
):
if
not
hasattr
(
param
,
"unslice"
):
setattr
(
param
,
"unslice"
,
True
)
return
param
def
_VarBaseWrapper
(
param
):
varbase
=
param
.
fw_storage
tmp_param
=
ParamBase
(
...
...
python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
浏览文件 @
23d559dd
...
...
@@ -57,12 +57,15 @@ class ShardingClipGrad:
@
imperative_base
.
no_grad
def
_dygraph_clip
(
self
,
params_grads
):
sum_square_fp
16
=
[]
sum_square_fp32
=
[]
sum_square_fp
32
,
sum_square_fp16
=
[],
[]
unslice_params_fp32
,
unslice_params_fp16
=
[],
[]
for
p
,
g
in
params_grads
:
p_slice
=
True
# using for slice parameter in sharding stage3
if
g
is
None
or
getattr
(
p
,
'need_clip'
,
True
)
is
False
:
continue
if
hasattr
(
p
,
"unslice"
):
p_slice
=
False
merge_grad
=
g
if
g
.
type
==
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
:
...
...
@@ -72,9 +75,11 @@ class ShardingClipGrad:
sum_square
=
layers
.
reduce_sum
(
square
)
if
p
.
dtype
==
paddle
.
float16
:
sum_square_fp16
.
append
(
sum_square
)
if
p_slice
:
sum_square_fp16
.
append
(
sum_square
)
else
:
unslice_params_fp16
.
append
(
sum_square
)
elif
p
.
dtype
==
paddle
.
float32
:
sum_square_fp32
.
append
(
sum_square
)
if
p_slice
:
sum_square_fp32
.
append
(
sum_square
)
else
:
unslice_params_fp32
.
append
(
sum_square
)
# global norm of non-distributed FP16 params_and_grads
if
len
(
sum_square_fp16
)
==
0
:
...
...
@@ -85,12 +90,28 @@ class ShardingClipGrad:
global_norm_fp16
=
paddle
.
cast
(
global_norm_fp16
,
dtype
=
paddle
.
float32
)
# global norm of non-distributed FP16 params_and_grads for slice parameter
if
len
(
unslice_params_fp16
)
==
0
:
global_unslice_fp16
=
paddle
.
to_tensor
([
0.
],
dtype
=
paddle
.
float32
)
else
:
global_unslice_fp16
=
layers
.
concat
(
unslice_params_fp16
)
global_unslice_fp16
=
layers
.
reduce_sum
(
global_unslice_fp16
)
global_unslice_fp16
=
paddle
.
cast
(
global_unslice_fp16
,
dtype
=
paddle
.
float32
)
# global norm of non-distributed FP32 params_and_grads
global_norm_fp32
=
layers
.
concat
(
sum_square_fp32
)
if
len
(
sum_square_fp32
)
!=
0
else
paddle
.
to_tensor
(
[
0.
],
dtype
=
paddle
.
float32
)
global_norm_fp32
=
layers
.
reduce_sum
(
global_norm_fp32
)
# global norm of non-distributed FP32 params_and_grads for slice parameter
global_unslice_fp32
=
layers
.
concat
(
unslice_params_fp32
)
if
len
(
unslice_params_fp32
)
!=
0
else
paddle
.
to_tensor
(
[
0.
],
dtype
=
paddle
.
float32
)
global_unslice_fp32
=
layers
.
reduce_sum
(
global_unslice_fp32
)
global_unslice_var
=
global_unslice_fp16
+
global_unslice_fp32
global_norm_var
=
global_norm_fp16
+
global_norm_fp32
# add all reduce to get global norm of distributed params_and_grads
...
...
@@ -98,6 +119,7 @@ class ShardingClipGrad:
with
device_guard
(
dev_id
,
"gpu"
):
paddle
.
distributed
.
all_reduce
(
global_norm_var
,
group
=
self
.
_group
)
global_norm_var
+=
global_unslice_var
global_norm_var
=
layers
.
sqrt
(
global_norm_var
)
max_global_norm
=
layers
.
fill_constant
(
shape
=
[
1
],
dtype
=
global_norm_var
.
dtype
,
value
=
self
.
clip_norm
)
...
...
python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py
浏览文件 @
23d559dd
...
...
@@ -145,6 +145,10 @@ def train_mlp(model,
loss
=
paddle
.
nn
.
functional
.
cross_entropy
(
input
=
out
,
label
=
label
)
avg_loss
=
paddle
.
mean
(
x
=
loss
.
cast
(
dtype
=
paddle
.
float32
))
if
batch_size
==
20
:
avg_loss
=
avg_loss
/
5
if
not
use_pure_fp16
:
avg_loss
.
backward
()
else
:
...
...
@@ -215,7 +219,7 @@ def test_stage2_stage3():
stage3_params
[
i
].
numpy
(),
stage3_params_add
[
i
].
numpy
(),
rtol
=
1e-6
,
atol
=
1e-
6
)
atol
=
1e-
4
)
# fp16
stage2_params
=
train_mlp
(
...
...
python/paddle/fluid/tests/unittests/dygraph_sharding_stage3_offload.py
浏览文件 @
23d559dd
...
...
@@ -28,7 +28,6 @@ from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import Shar
from
paddle.distributed.fleet.meta_parallel.sharding.sharding_utils
import
ShardingScaler
epoch
=
10
batch_size
=
32
paddle
.
seed
(
2022
)
np
.
random
.
seed
(
2022
)
base_lr
=
0.1
...
...
@@ -80,6 +79,7 @@ def train_mlp(model,
use_pure_fp16
=
False
,
accumulate_grad
=
False
,
offload
=
False
,
batch_size
=
100
,
convert2cpu
=
False
):
group
=
paddle
.
distributed
.
new_group
([
0
,
1
])
optimizer
=
optimizer_setting
(
model
=
model
,
use_pure_fp16
=
use_pure_fp16
)
...
...
@@ -91,7 +91,11 @@ def train_mlp(model,
scaler
=
ShardingScaler
(
scaler
)
model
=
ShardingStage3
(
model
,
optimizer
=
optimizer
,
group
=
group
,
offload
=
offload
)
model
,
optimizer
=
optimizer
,
group
=
group
,
offload
=
offload
,
accumulate_grads
=
accumulate_grad
)
train_reader
=
paddle
.
batch
(
reader_decorator
(),
batch_size
=
batch_size
,
drop_last
=
True
)
...
...
@@ -115,10 +119,15 @@ def train_mlp(model,
loss
=
paddle
.
nn
.
functional
.
cross_entropy
(
input
=
out
,
label
=
label
)
avg_loss
=
paddle
.
mean
(
x
=
loss
.
cast
(
dtype
=
paddle
.
float32
))
if
accumulate_grad
:
avg_loss
=
avg_loss
/
5
if
not
use_pure_fp16
:
avg_loss
.
backward
()
else
:
scaler
.
scale
(
avg_loss
).
backward
()
if
not
accumulate_grad
:
if
not
use_pure_fp16
:
optimizer
.
step
()
...
...
@@ -172,12 +181,14 @@ def test_stage3_offload():
atol
=
1e-2
)
# fp32 accumulate grad offload
stage3_params
=
train_mlp
(
mlp5
,
use_pure_fp16
=
False
,
accumulate_grad
=
True
)
stage3_params
=
train_mlp
(
mlp5
,
use_pure_fp16
=
False
,
batch_size
=
20
,
accumulate_grad
=
True
)
stage3_params_offload
=
train_mlp
(
mlp6
,
use_pure_fp16
=
False
,
accumulate_grad
=
True
,
offload
=
True
,
batch_size
=
20
,
convert2cpu
=
True
)
for
i
in
range
(
len
(
stage3_params
)):
np
.
testing
.
assert_allclose
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录