Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
月光在发光
Paddle
提交
46823104
P
Paddle
项目概览
月光在发光
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
46823104
编写于
1月 24, 2022
作者:
B
Baibaifan
提交者:
GitHub
1月 24, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add sharding stage3 offload (#38989)
上级
f4623876
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
449 addition
and
112 deletion
+449
-112
python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
...stributed/fleet/meta_parallel/sharding/sharding_stage3.py
+208
-69
python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py
...n/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py
+43
-40
python/paddle/fluid/tests/unittests/dygraph_sharding_stage3_offload.py
.../fluid/tests/unittests/dygraph_sharding_stage3_offload.py
+192
-0
python/paddle/fluid/tests/unittests/test_dygraph_sharding_stage2.py
...dle/fluid/tests/unittests/test_dygraph_sharding_stage2.py
+2
-2
python/paddle/fluid/tests/unittests/test_dygraph_sharding_stage3.py
...dle/fluid/tests/unittests/test_dygraph_sharding_stage3.py
+4
-1
未找到文件。
python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
浏览文件 @
46823104
...
...
@@ -33,7 +33,7 @@ from paddle.fluid.framework import ParamBase
from
paddle.fluid.clip
import
ClipGradByGlobalNorm
from
paddle.distributed.collective
import
_get_global_group
from
.sharding_utils
import
Type
,
ShardingClipGrad
from
.sharding_utils
import
Type
,
ShardingClipGrad
,
device_guard
from
..pp_utils.utils
import
_all_gather
# CUDA alignment 256 bytes
...
...
@@ -56,6 +56,13 @@ class ShardingStage3(nn.Layer):
.. ZeRO: https://arxiv.org/pdf/1910.02054.pdf.
"""
# TODO (Baibaifan)
# Feature Notes::
# 1. The model supports the segmentation of parameters by global ranks in layers.
# 2. Support communication flow and computing flow.
# 3. Support offload function.
# 4. Support the establishment of independent communication groups.
def
__init__
(
self
,
layer
,
optimizer
,
...
...
@@ -77,6 +84,15 @@ class ShardingStage3(nn.Layer):
self
.
_offload
=
offload
self
.
_sync_comm
=
sync_comm
global
DEV
DEV
=
"cpu"
if
paddle
.
get_device
()
==
"cpu"
else
paddle
.
get_device
(
).
split
(
":"
)[
0
]
global
DEV_ID
DEV_ID
=
0
if
paddle
.
get_device
()
==
"cpu"
else
int
(
paddle
.
get_device
()
.
split
(
":"
)[
1
])
global
param2dtype
param2dtype
=
dict
()
# Communication group establishment
self
.
_group
=
dist
.
new_group
(
_get_global_group
()
.
ranks
)
if
group
is
None
else
group
...
...
@@ -85,6 +101,9 @@ class ShardingStage3(nn.Layer):
self
.
_rank
=
self
.
_group
.
rank
self
.
_global_root_rank
=
0
# picking rank 0 as the reference
self
.
_global_ranks
=
self
.
_group
.
ranks
# Parameter segmentation for global ranks
# After flatten -> self._param2buffer_size, self._param2buffer, self._trainable_params
self
.
_param2buffer_size
=
dict
()
# {param.name: size}
self
.
_param2buffer
=
dict
(
)
# {param.name: [(start0, end0),(start1, end1), ...]}
...
...
@@ -116,12 +135,16 @@ class ShardingStage3(nn.Layer):
self
.
_order_tracer
=
OrderedDict
()
self
.
_order_tracer
[
"order"
]
=
0
self
.
_order_tracer
[
"layer"
]
=
[]
# Register task flow
self
.
_task_flow
=
TaskFlow
()
# Register forward hooks
self
.
_register_forward_hooks
(
self
.
_layer
)
# Register backward parameter hooks
self
.
_register_backward_hooks
()
# Redefine optimizer step and clear function
self
.
_redefine_opt_step
()
self
.
_redefine_opt_clear
()
...
...
@@ -152,7 +175,6 @@ class ShardingStage3(nn.Layer):
param
,
"fw_storage"
),
"Find {} don't have fw_storage attribute."
.
format
(
param
.
name
)
# param.bw_storage.zero_()
param
.
fw_storage
.
clear_gradient
(
False
)
param
.
fw_storage
.
_gradient_set_empty
(
False
)
param
.
bw_storage
.
_clear
()
...
...
@@ -192,6 +214,9 @@ class ShardingStage3(nn.Layer):
return
fw
def
_segment_rank_params
(
self
,
layer
,
name
=
"last_layer"
):
"""
Flatten parameters according to layer.
"""
current_layer_params
=
_current_layer_params
(
layer
)
if
current_layer_params
:
CHECK_LAYER
[
id
(
layer
)]
=
name
...
...
@@ -201,6 +226,10 @@ class ShardingStage3(nn.Layer):
self
.
_segment_rank_params
(
sub_layer
,
name
)
def
_flatten_layer_params
(
self
,
layer
,
current_layer_params
):
"""
Parameter segmentation and memory integration.
"""
def
_add_manage_info
(
trainable_param
):
return
_PartitionParam
(
trainable_param
)
...
...
@@ -238,8 +267,13 @@ class ShardingStage3(nn.Layer):
# 3.Flatten layer params and release other rank buffer
self
.
_param_storage
(
param
,
buffer_size
)
# Record param's dtype
param2dtype
[
param
.
name
]
=
param
.
dtype
def
_param_storage
(
self
,
param
,
buffer_size
):
"""
This is a function to simplify the handling of parameter InternalStorages.
"""
assert
isinstance
(
buffer_size
,
int
)
value
=
np
.
zeros
(
buffer_size
,
...
...
@@ -264,16 +298,31 @@ class ShardingStage3(nn.Layer):
param
.
_clear
()
# Current rank param_storage
param
.
fw_storage
=
core
.
VarBase
(
buffer
.
_slice
(
start
,
end
),
"slice@"
+
param
.
name
)
if
self
.
_offload
:
param
.
fw_storage
=
core
.
VarBase
(
buffer
.
_slice
(
start
,
end
),
core
.
CPUPlace
(),
"slice@"
+
param
.
name
)
else
:
param
.
fw_storage
=
core
.
VarBase
(
buffer
.
_slice
(
start
,
end
),
"slice@"
+
param
.
name
)
param
.
status
=
"part"
# Updata optimizer master weights
if
param
.
dtype
==
Type
.
fp16
.
value
:
if
param
.
dtype
==
Type
.
fp16
.
value
and
not
self
.
_offload
:
self
.
_optim
.
_master_weights
[
param
.
fw_storage
.
name
]
=
paddle
.
cast
(
param
.
fw_storage
,
Type
.
fp32
.
value
)
def
_register_forward_hooks
(
self
,
layer
):
"""
Register pylayer to manage memory slices.
There are four stages:
FW
1. Before the forward layers, synchronize the full parameters.
2. After the forward layers, release the full parameter and keep the parameter slice.
BW
3. Before the backward layers, synchronize the full parameters and create param's grad.
4. After the gradient accumulation, release the full parameter and keep the parameter slice.
"""
current_layer_params
=
_current_layer_params
(
layer
)
if
current_layer_params
:
self
.
_register_forward_all_hooks
(
layer
,
self
.
_task_flow
)
...
...
@@ -286,13 +335,13 @@ class ShardingStage3(nn.Layer):
return
ForwardPreHooks
(
layer
,
self
.
_order_tracer
,
self
.
_trainable_params
,
self
.
_param2buffer
,
self
.
_rank
,
self
.
_group
,
self
.
_sync_comm
,
task_flow
)
self
.
_offload
,
task_flow
)
def
_forward_post_hook
(
layer
,
inputs
,
outputs
):
return
ForwardPostHooks
.
apply
(
outputs
,
layer
,
self
.
_order_tracer
,
self
.
_trainable_params
,
self
.
_param2buffer
,
self
.
_param2buffer_size
,
self
.
_rank
,
self
.
_group
,
self
.
_sync_comm
,
task_flow
)
self
.
_group
,
self
.
_sync_comm
,
self
.
_offload
,
task_flow
)
# register previous forward hooks
sub_layer
.
register_forward_pre_hook
(
_forward_pre_hook
)
...
...
@@ -302,6 +351,10 @@ class ShardingStage3(nn.Layer):
@
paddle
.
no_grad
()
def
_sync_buffers
(
self
):
"""
Sync all the param buffers from all ranks (exp: batch norm statistics).
"""
for
buffer
in
self
.
_layer
.
buffers
(
include_sublayers
=
True
):
dist
.
broadcast
(
buffer
,
...
...
@@ -319,6 +372,9 @@ class ShardingStage3(nn.Layer):
return
getattr
(
self
.
_layer
,
name
)
def
_update_params
(
self
):
"""
Update parameters to optimizer memory slice.
"""
update_list
=
[]
assert
len
(
self
.
_trainable_params
.
keys
())
>
0
current_layer_params
=
self
.
_layer
.
parameters
(
include_sublayers
=
True
)
...
...
@@ -331,36 +387,35 @@ class ShardingStage3(nn.Layer):
param
.
name
)
if
self
.
_accumulate_grads
:
param
.
bw_storage
.
scale_
(
scale
=
self
.
_world_size_scaling
)
if
self
.
_offload
:
with
device_guard
(
device
=
"cpu"
):
param
.
bw_storage
.
scale_
(
scale
=
self
.
_world_size_scaling
)
else
:
param
.
bw_storage
.
scale_
(
scale
=
self
.
_world_size_scaling
)
param
.
fw_storage
=
_VarBaseWrapper
(
param
)
param
.
fw_storage
.
_copy_gradient_from
(
param
.
bw_storage
)
update_list
.
append
(
param
)
return
update_list
def
get_all_parameters
(
self
):
def
get_all_parameters
(
self
,
convert2cpu
=
False
):
"""
Get the full parameters and return the corresponding task flows.
"""
assert
len
(
self
.
_trainable_params
.
keys
())
>
0
current_layer_params
=
self
.
_layer
.
parameters
(
include_sublayers
=
True
)
trainable_params
=
list
(
filter
(
lambda
x
:
x
.
trainable
,
current_layer_params
))
for
param
in
trainable_params
:
if
param
.
use_count
>
0
:
continue
assert
hasattr
(
param
,
"fw_storage"
),
"Find {} don't have fw_storage attribute"
.
format
(
param
.
name
)
full_param
=
_all_gather
(
param
.
fw_storage
,
self
.
_group
,
use_calc_stream
=
True
)
dist
.
wait
(
tensor
=
full_param
,
group
=
self
.
_group
,
use_calc_stream
=
True
)
core
.
VarBase
(
full_param
.
_slice
(
0
,
param
.
_numel
())).
_share_buffer_to
(
param
)
param
.
value
().
get_tensor
().
_set_dims
(
param
.
shape
)
param
.
fw_storage
.
_clear
()
param
.
fw_storage
=
None
param
.
status
=
"all"
param
.
use_count
+=
1
t_flow
=
_allgather_buffer
(
trainable_params
,
self
.
_group
,
use_calc_stream
=
True
,
task_flow
=
TaskFlow
(),
sync_wait
=
True
,
offload
=
self
.
_offload
,
convert2cpu
=
convert2cpu
)
if
convert2cpu
:
for
param
in
current_layer_params
:
t_flow
.
full_param
[
param
.
name
].
_share_buffer_to
(
param
)
self
.
_optim
.
_parameter_list
=
self
.
_ori_parameter_list
self
.
_optim
.
_param_groups
=
self
.
_ori_param_groups
...
...
@@ -393,13 +448,28 @@ class ShardingStage3(nn.Layer):
use_calc_stream
=
True
)
start
,
end
=
self
.
_param2buffer
[
param
.
name
][
self
.
_rank
]
if
not
self
.
_accumulate_grads
or
param
.
bw_storage
is
None
:
if
not
self
.
_accumulate_grads
or
param
.
bw_storage
is
None
or
not
param
.
bw_storage
.
value
(
).
get_tensor
().
_is_initialized
():
param
.
bw_storage
=
core
.
VarBase
(
full_grad
.
_slice
(
start
,
end
)).
detach
().
clone
()
if
self
.
_offload
:
param
.
bw_storage
=
_device2cpu
(
param
.
bw_storage
,
True
)
else
:
param
.
bw_storage
.
add_
(
core
.
VarBase
(
full_grad
.
_slice
(
start
,
end
)).
detach
()
.
clone
())
if
self
.
_offload
:
cpu_grad
=
_device2cpu
(
core
.
VarBase
(
full_grad
.
_slice
(
start
,
end
))
.
detach
().
clone
(),
True
)
param
.
bw_storage
=
paddle
.
add
(
param
.
bw_storage
,
cpu_grad
)
else
:
# param.bw_storage.add_(
# core.VarBase(full_grad._slice(start, end))
# .detach().clone())
param
.
bw_storage
=
paddle
.
add
(
param
.
bw_storage
,
core
.
VarBase
(
full_grad
.
_slice
(
start
,
end
)).
detach
().
clone
())
param
.
clear_gradient
(
False
)
param
.
_gradient_set_empty
(
False
)
tmp_var
=
self
.
_task_flow
.
full_grad
.
pop
(
param
.
name
)
...
...
@@ -410,15 +480,16 @@ class ShardingStage3(nn.Layer):
param
.
use_count
=
0
param
.
_clear
()
start
,
end
=
self
.
_param2buffer
[
param
.
name
][
self
.
_rank
]
with
paddle
.
amp
.
auto_cast
(
enable
=
False
):
param
.
fw_storage
=
core
.
VarBase
(
self
.
_task_flow
.
full_param
[
param
.
name
].
_slice
(
start
,
end
),
param
.
name
+
"@slice"
).
detach
().
clone
()
param
.
fw_storage
=
core
.
VarBase
(
self
.
_task_flow
.
full_param
[
param
.
name
].
_slice
(
start
,
end
),
param
.
name
+
"@slice"
).
detach
().
clone
()
param
.
status
=
"part"
tmp_var
=
self
.
_task_flow
.
full_param
.
pop
(
param
.
name
)
tmp_var
.
_clear
()
if
self
.
_offload
:
param
.
fw_storage
=
_device2cpu
(
param
.
fw_storage
,
True
)
return
reduce
def
_redefine_opt_step
(
self
):
...
...
@@ -429,7 +500,11 @@ class ShardingStage3(nn.Layer):
def
_opt_step
(
self
):
if
not
update_scaler
:
params_slice_func
()
opt_step
()
if
self
.
offload
:
with
device_guard
(
device
=
"cpu"
):
opt_step
()
else
:
opt_step
()
self
.
_optim
.
step
=
MethodType
(
_opt_step
,
self
.
_optim
)
...
...
@@ -443,7 +518,7 @@ class ShardingStage3(nn.Layer):
def
ForwardPreHooks
(
layer
,
order_tracer
,
trainable_params
,
param2buffer
,
rank
,
group
,
sync_comm
,
task_flow
):
group
,
sync_comm
,
offload
,
task_flow
):
# Record layer's id
layer_id
=
id
(
layer
)
...
...
@@ -451,21 +526,28 @@ def ForwardPreHooks(layer, order_tracer, trainable_params, param2buffer, rank,
if
layer_id
not
in
order_tracer
.
keys
()
or
sync_comm
:
use_calc
,
sync_wait
=
True
,
True
# Whether to use calc stream
task_flow
.
use_calc
[
layer_id
]
=
use_calc
else
:
# Whether to use calc stream
task_flow
.
use_calc
[
layer_id
]
=
use_calc
_wait_layer
(
trainable_params
,
layer_id
,
task_flow
,
group
,
use_calc
)
# wait current layer params
_wait_layer
(
trainable_params
[
layer_id
],
task_flow
,
group
,
use_calc
,
offload
)
if
layer_id
==
order_tracer
[
"layer"
][
-
1
]:
return
order_
=
order_tracer
[
layer_id
]
layer_id
=
order_tracer
[
"layer"
][
order_
+
1
]
_allgather_buffer
(
layer_id
,
trainable_params
,
trainable_params
[
layer_id
],
group
,
use_calc_stream
=
use_calc
,
task_flow
=
task_flow
,
sync_wait
=
sync_wait
)
sync_wait
=
sync_wait
,
offload
=
offload
)
return
...
...
@@ -473,15 +555,20 @@ class ForwardPostHooks(PyLayer):
@
staticmethod
def
forward
(
ctx
,
inputs
,
layer
,
order_tracer
,
trainable_params
,
param2buffer
,
param2buffer_size
,
rank
,
group
,
sync_comm
,
task_flow
):
_release_param
(
layer
,
trainable_params
,
param2buffer
,
rank
,
task_flow
)
offload
,
task_flow
):
layer_id
=
id
(
layer
)
# release current layer full params
_release_param
(
trainable_params
[
layer_id
],
param2buffer
,
rank
,
task_flow
,
offload
)
if
layer_id
not
in
order_tracer
.
keys
():
order_
=
order_tracer
[
"order"
]
order_tracer
[
layer_id
]
=
order_
order_tracer
[
"order"
]
+=
1
order_tracer
[
"layer"
].
append
(
layer_id
)
#Record bw info
ctx
.
order_tracer
=
order_tracer
ctx
.
task_flow
=
task_flow
ctx
.
group
=
group
...
...
@@ -489,6 +576,7 @@ class ForwardPostHooks(PyLayer):
ctx
.
sync_comm
=
sync_comm
ctx
.
trainable_params
=
trainable_params
ctx
.
param2buffer_size
=
param2buffer_size
ctx
.
offload
=
offload
return
inputs
...
...
@@ -502,31 +590,39 @@ class ForwardPostHooks(PyLayer):
trainable_params
=
ctx
.
trainable_params
param2buffer_size
=
ctx
.
param2buffer_size
sync_comm
=
ctx
.
sync_comm
offload
=
ctx
.
offload
layer_id
=
id
(
layer
)
use_calc
,
sync_wait
=
False
,
False
# Allgather params synchronization
if
sync_comm
:
use_calc
,
sync_wait
=
True
,
True
_allgather_buffer
(
layer_id
,
trainable_params
,
trainable_params
[
layer_id
],
group
,
use_calc_stream
=
use_calc
,
task_flow
=
task_flow
,
sync_wait
=
sync_wait
)
sync_wait
=
sync_wait
,
offload
=
offload
)
else
:
_wait_layer
(
trainable_params
,
layer_id
,
task_flow
,
group
,
use_calc
)
_create_params_grad
(
layer
,
trainable_params
,
param2buffer_size
,
_wait_layer
(
trainable_params
[
layer_id
],
task_flow
,
group
,
use_calc
,
offload
)
# Create params's grad
_create_params_grad
(
trainable_params
[
layer_id
],
param2buffer_size
,
task_flow
)
# Whether to use calc stream
task_flow
.
use_calc
[
layer_id
]
=
use_calc
if
layer_id
!=
order_tracer
[
"layer"
][
0
]
and
not
sync_comm
:
layer_next_id
=
order_tracer
[
"layer"
][
order_tracer
[
layer_id
]
-
1
]
_allgather_buffer
(
layer_next_id
,
trainable_params
,
trainable_params
[
layer_next_id
],
group
,
use_calc_stream
=
use_calc
,
task_flow
=
task_flow
,
sync_wait
=
sync_wait
)
sync_wait
=
sync_wait
,
offload
=
offload
)
return
args
...
...
@@ -547,8 +643,12 @@ class TaskFlow:
self
.
callback
=
callback
def
_release_param
(
layer
,
trainable_params
,
param2buffer
,
rank
,
task_flow
):
for
param
in
trainable_params
[
id
(
layer
)]:
def
_release_param
(
trainable_params
,
param2buffer
,
rank
,
task_flow
,
offload
=
False
):
for
param
in
trainable_params
:
# async communicate share weight not clear
param
.
use_count
-=
1
if
param
.
use_count
==
0
:
...
...
@@ -562,11 +662,18 @@ def _release_param(layer, trainable_params, param2buffer, rank, task_flow):
param
.
status
=
"part"
tmp_var
=
task_flow
.
full_param
.
pop
(
param
.
name
)
tmp_var
.
_clear
()
if
offload
:
param
.
fw_storage
=
_device2cpu
(
param
.
fw_storage
)
return
def
_wait_layer
(
trainable_params
,
layer_id
,
task_flow
,
group
,
use_calc_stream
):
for
param
in
trainable_params
[
layer_id
]:
def
_wait_layer
(
trainable_params
,
task_flow
,
group
,
use_calc_stream
,
offload
=
False
):
for
param
in
trainable_params
:
if
param
.
status
==
"all"
:
param
.
use_count
+=
1
continue
...
...
@@ -576,36 +683,43 @@ def _wait_layer(trainable_params, layer_id, task_flow, group, use_calc_stream):
paddle
.
device
.
cuda
.
synchronize
()
core
.
VarBase
(
full_param
.
_slice
(
0
,
param
.
_numel
())).
_share_buffer_to
(
param
)
param
.
value
().
get_tensor
().
_set_dims
(
param
.
shape
)
param
.
fw_storage
.
_clear
()
param
.
fw_storage
=
None
param
.
status
=
"all"
param
.
use_count
+=
1
else
:
_allgather_buffer
(
layer_id
,
trainable_params
,
group
,
use_calc_stream
,
task_flow
,
sync_wait
=
True
)
use_calc_stream
=
True
,
task_flow
=
task_flow
,
sync_wait
=
True
,
offload
=
offload
)
break
return
task_flow
def
_allgather_buffer
(
layer_id
,
trainable_params
,
def
_allgather_buffer
(
trainable_params
,
group
,
use_calc_stream
,
task_flow
,
sync_wait
=
False
):
for
param
in
trainable_params
[
layer_id
]:
sync_wait
=
False
,
offload
=
False
,
convert2cpu
=
False
):
for
param
in
trainable_params
:
if
param
.
status
==
"all"
:
param
.
use_count
+=
1
continue
if
offload
:
param
.
fw_storage
=
_cpu2device
(
param
)
with
paddle
.
amp
.
auto_cast
(
enable
=
False
):
full_param
=
_all_gather
(
param
.
fw_storage
,
group
,
use_calc_stream
=
use_calc_stream
)
# Allgather current layer in the 1st step
if
sync_wait
:
with
paddle
.
amp
.
auto_cast
(
enable
=
False
):
dist
.
wait
(
...
...
@@ -614,18 +728,26 @@ def _allgather_buffer(layer_id,
use_calc_stream
=
use_calc_stream
)
core
.
VarBase
(
full_param
.
_slice
(
0
,
param
.
_numel
())).
_share_buffer_to
(
param
)
param
.
value
().
get_tensor
().
_set_dims
(
param
.
shape
)
param
.
fw_storage
.
_clear
()
param
.
fw_storage
=
None
param
.
status
=
"all"
param
.
use_count
+=
1
task_flow
.
full_param
[
param
.
name
]
=
full_param
# parameter converts to cpu
if
convert2cpu
:
p_name
=
param
.
name
param
=
_device2cpu
(
param
)
tmp_var
=
task_flow
.
full_param
.
pop
(
p_name
)
tmp_var
.
_clear
()
task_flow
.
full_param
[
p_name
]
=
param
return
task_flow
@
paddle
.
no_grad
()
def
_create_params_grad
(
layer
,
trainable_params
,
param2buffer_size
,
task_flow
):
for
param
in
trainable_params
[
id
(
layer
)]
:
def
_create_params_grad
(
trainable_params
,
param2buffer_size
,
task_flow
):
for
param
in
trainable_params
:
if
param
.
name
in
task_flow
.
full_grad
.
keys
():
continue
assert
isinstance
(
param2buffer_size
[
param
.
name
],
int
)
...
...
@@ -668,6 +790,23 @@ def _OptimizerWrapper(optimizer, offload, group, update_params_slice):
return
optimizer
def
_device2cpu
(
trans_param
,
convert_dtype
=
False
):
if
convert_dtype
:
trans_param
=
paddle
.
cast
(
trans_param
,
Type
.
fp32
.
value
)
tmp_p
=
trans_param
.
cpu
()
trans_param
.
_clear
()
return
tmp_p
def
_cpu2device
(
param
):
tmp_p
=
param
.
fw_storage
.
cuda
(
DEV_ID
)
param
.
fw_storage
.
_clear
()
if
tmp_p
.
dtype
==
Type
.
fp32
.
value
and
param2dtype
[
param
.
name
]
==
Type
.
fp16
.
value
:
tmp_p
=
paddle
.
cast
(
tmp_p
,
Type
.
fp16
.
value
)
return
tmp_p
def
_current_layer_params
(
layer
):
return
layer
.
parameters
(
include_sublayers
=
False
)
+
list
(
layer
.
extra_parameters
)
if
hasattr
(
...
...
python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py
浏览文件 @
46823104
...
...
@@ -30,7 +30,6 @@ from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import Shar
from
paddle.distributed.fleet.meta_parallel.sharding.sharding_utils
import
ShardingScaler
epoch
=
10
batch_size
=
32
paddle
.
seed
(
2021
)
np
.
random
.
seed
(
2021
)
base_lr
=
0.1
...
...
@@ -66,10 +65,10 @@ def reader_decorator(linear_size=1000):
def
optimizer_setting
(
model
,
use_pure_fp16
,
opt_group
=
False
):
clip
=
paddle
.
nn
.
ClipGradByGlobalNorm
(
clip_norm
=
1.0
)
optimizer
=
paddle
.
optimizer
.
AdamW
(
optimizer
=
paddle
.
optimizer
.
Momentum
(
parameters
=
[{
"params"
:
model
.
parameters
(
)
}]
if
opt_group
else
model
.
parameters
(
),
"params"
:
list
(
model
.
parameters
()
)
}]
if
opt_group
else
list
(
model
.
parameters
()
),
learning_rate
=
0.001
,
weight_decay
=
0.00001
,
grad_clip
=
clip
,
...
...
@@ -82,6 +81,7 @@ def train_mlp(model,
sharding_stage
,
use_pure_fp16
=
False
,
accumulate_grad
=
False
,
batch_size
=
100
,
opt_group
=
False
,
recompute
=
False
):
group
=
paddle
.
distributed
.
new_group
([
0
,
1
])
...
...
@@ -104,10 +104,14 @@ def train_mlp(model,
optimizer
,
group
=
group
,
buffer_max_size
=
2
**
21
,
accumulate_grads
=
accumulate_grad
)
accumulate_grads
=
batch_size
==
20
)
elif
sharding_stage
==
3
:
model
=
ShardingStage3
(
model
,
optimizer
=
optimizer
,
group
=
group
,
sync_comm
=
recompute
)
model
,
optimizer
=
optimizer
,
group
=
group
,
accumulate_grads
=
batch_size
==
20
,
sync_comm
=
recompute
)
train_reader
=
paddle
.
batch
(
reader_decorator
(),
batch_size
=
batch_size
,
drop_last
=
True
)
...
...
@@ -131,21 +135,22 @@ def train_mlp(model,
loss
=
paddle
.
nn
.
functional
.
cross_entropy
(
input
=
out
,
label
=
label
)
avg_loss
=
paddle
.
mean
(
x
=
loss
.
cast
(
dtype
=
paddle
.
float32
))
if
not
use_pure_fp16
:
avg_loss
.
backward
()
else
:
scaler
.
scale
(
avg_loss
).
backward
()
if
not
accumulate_grad
:
if
not
use_pure_fp16
:
avg_loss
.
backward
()
optimizer
.
step
()
else
:
scaler
.
scale
(
avg_loss
).
backward
()
scaler
.
step
(
optimizer
)
scaler
.
update
()
optimizer
.
clear_grad
()
if
accumulate_grad
:
if
not
use_pure_fp16
:
avg_loss
.
backward
()
optimizer
.
step
()
else
:
scaler
.
scale
(
avg_loss
).
backward
()
scaler
.
step
(
optimizer
)
scaler
.
update
()
optimizer
.
clear_grad
()
...
...
@@ -168,48 +173,50 @@ def test_stage2_stage3():
mlp8
.
set_state_dict
(
state_dict
)
# fp32
stage2_params
=
train_mlp
(
mlp1
,
sharding_stage
=
2
,
use_pure_fp16
=
False
,
opt_group
=
Tru
e
)
mlp1
,
sharding_stage
=
2
,
use_pure_fp16
=
False
,
opt_group
=
Fals
e
)
stage3_params
=
train_mlp
(
mlp2
,
sharding_stage
=
3
,
use_pure_fp16
=
False
,
opt_group
=
True
)
mlp2
,
sharding_stage
=
3
,
use_pure_fp16
=
False
,
opt_group
=
False
)
for
i
in
range
(
len
(
stage2_params
)):
for
j
in
range
(
len
(
stage3_params
)):
if
stage2_params
[
i
].
name
==
stage3_params
[
j
].
name
:
np
.
testing
.
assert_allclose
(
stage2_params
[
i
].
numpy
()
,
stage3_params
[
j
].
numpy
(),
rtol
=
1e-6
)
np
.
testing
.
assert_allclose
(
stage2_params
[
i
].
numpy
(),
stage3_params
[
i
].
numpy
(),
rtol
=
1e-6
,
atol
=
1e-6
)
# fp32 accumulate grad
stage
2
_params
=
train_mlp
(
stage
3
_params
=
train_mlp
(
mlp3
,
sharding_stage
=
2
,
sharding_stage
=
3
,
use_pure_fp16
=
False
,
accumulate_grad
=
True
,
opt_group
=
True
)
stage3_params
=
train_mlp
(
stage3_params
_add
=
train_mlp
(
mlp4
,
sharding_stage
=
3
,
use_pure_fp16
=
False
,
accumulate_grad
=
True
,
batch_size
=
20
,
opt_group
=
True
)
for
i
in
range
(
len
(
stage
2
_params
)):
for
j
in
range
(
len
(
stage3_params
)):
if
stage2_params
[
i
].
name
==
stage3_params
[
j
].
name
:
np
.
testing
.
assert_allclose
(
stage2_params
[
i
].
numpy
()
,
stage3_params
[
j
].
numpy
(),
rtol
=
1e-6
)
for
i
in
range
(
len
(
stage
3
_params
)):
np
.
testing
.
assert_allclose
(
stage3_params
[
i
].
numpy
(),
stage3_params_add
[
i
].
numpy
(),
rtol
=
1e-6
,
atol
=
1e-6
)
# fp16
stage2_params
=
train_mlp
(
mlp5
,
sharding_stage
=
2
,
use_pure_fp16
=
True
,
opt_group
=
False
)
stage3_params
=
train_mlp
(
mlp6
,
sharding_stage
=
3
,
use_pure_fp16
=
True
,
opt_group
=
False
)
for
i
in
range
(
len
(
stage2_params
)):
for
j
in
range
(
len
(
stage3_params
)):
if
stage2_params
[
i
].
name
==
stage3_params
[
j
].
name
:
np
.
testing
.
assert_allclose
(
stage2_params
[
i
].
numpy
()
,
stage3_params
[
j
].
numpy
(),
rtol
=
1e-6
)
np
.
testing
.
assert_allclose
(
stage2_params
[
i
].
numpy
(),
stage3_params
[
i
].
numpy
(),
rtol
=
1e-4
,
atol
=
1e-4
)
# fp16 recompute
stage3_params
=
train_mlp
(
mlp7
,
sharding_stage
=
3
,
use_pure_fp16
=
True
,
opt_group
=
False
)
...
...
@@ -220,12 +227,8 @@ def test_stage2_stage3():
opt_group
=
False
,
recompute
=
True
)
for
i
in
range
(
len
(
stage3_params
)):
for
j
in
range
(
len
(
stage3_params_re
)):
if
stage3_params
[
i
].
name
==
stage3_params_re
[
j
].
name
:
np
.
testing
.
assert_allclose
(
stage3_params
[
i
].
numpy
(),
stage3_params_re
[
j
].
numpy
(),
rtol
=
1e-6
)
np
.
testing
.
assert_allclose
(
stage3_params
[
i
].
numpy
(),
stage3_params_re
[
i
].
numpy
(),
rtol
=
1e-6
)
return
...
...
python/paddle/fluid/tests/unittests/dygraph_sharding_stage3_offload.py
0 → 100644
浏览文件 @
46823104
# -*- coding: UTF-8 -*-
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
argparse
import
ast
import
time
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph.nn
import
Linear
from
paddle.distributed
import
fleet
from
paddle.fluid.dygraph
import
nn
from
paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3
import
ShardingStage3
from
paddle.distributed.fleet.meta_parallel.sharding.sharding_utils
import
ShardingScaler
epoch
=
10
batch_size
=
32
paddle
.
seed
(
2022
)
np
.
random
.
seed
(
2022
)
base_lr
=
0.1
momentum_rate
=
0.9
l2_decay
=
1e-4
fleet
.
init
(
is_collective
=
True
)
class
MLP
(
fluid
.
Layer
):
def
__init__
(
self
,
linear_size
=
1000
,
param_attr
=
None
,
bias_attr
=
None
):
super
(
MLP
,
self
).
__init__
()
self
.
_linear1
=
Linear
(
linear_size
,
linear_size
)
self
.
_linear2
=
Linear
(
linear_size
,
linear_size
)
self
.
_linear3
=
Linear
(
linear_size
,
10
)
def
forward
(
self
,
inputs
):
y
=
self
.
_linear1
(
inputs
)
y
=
self
.
_linear2
(
y
)
y
=
self
.
_linear3
(
y
)
return
y
def
reader_decorator
(
linear_size
=
1000
):
def
__reader__
():
for
_
in
range
(
100
):
img
=
np
.
random
.
rand
(
linear_size
).
astype
(
'float32'
)
label
=
np
.
ones
(
1
).
astype
(
'int64'
)
yield
img
,
label
return
__reader__
def
optimizer_setting
(
model
,
use_pure_fp16
,
opt_group
=
False
):
clip
=
paddle
.
nn
.
ClipGradByGlobalNorm
(
clip_norm
=
1.0
)
optimizer
=
paddle
.
optimizer
.
AdamW
(
parameters
=
[{
"params"
:
model
.
parameters
()
}]
if
opt_group
else
model
.
parameters
(),
learning_rate
=
0.001
,
weight_decay
=
0.00001
,
grad_clip
=
clip
,
multi_precision
=
use_pure_fp16
)
return
optimizer
def
train_mlp
(
model
,
use_pure_fp16
=
False
,
accumulate_grad
=
False
,
offload
=
False
,
convert2cpu
=
False
):
group
=
paddle
.
distributed
.
new_group
([
0
,
1
])
optimizer
=
optimizer_setting
(
model
=
model
,
use_pure_fp16
=
use_pure_fp16
)
if
use_pure_fp16
:
model
=
paddle
.
amp
.
decorate
(
models
=
model
,
level
=
'O2'
,
save_dtype
=
'float32'
)
scaler
=
paddle
.
amp
.
GradScaler
(
init_loss_scaling
=
32768
)
scaler
=
ShardingScaler
(
scaler
)
model
=
ShardingStage3
(
model
,
optimizer
=
optimizer
,
group
=
group
,
offload
=
offload
)
train_reader
=
paddle
.
batch
(
reader_decorator
(),
batch_size
=
batch_size
,
drop_last
=
True
)
train_loader
=
paddle
.
io
.
DataLoader
.
from_generator
(
capacity
=
32
,
use_double_buffer
=
True
,
iterable
=
True
,
return_list
=
True
,
use_multiprocess
=
True
)
train_loader
.
set_sample_list_generator
(
train_reader
)
for
eop
in
range
(
epoch
):
model
.
train
()
for
batch_id
,
data
in
enumerate
(
train_loader
()):
img
,
label
=
data
label
.
stop_gradient
=
True
img
.
stop_gradient
=
True
with
paddle
.
amp
.
auto_cast
(
True
,
level
=
'O2'
):
out
=
model
(
img
)
loss
=
paddle
.
nn
.
functional
.
cross_entropy
(
input
=
out
,
label
=
label
)
avg_loss
=
paddle
.
mean
(
x
=
loss
.
cast
(
dtype
=
paddle
.
float32
))
if
not
use_pure_fp16
:
avg_loss
.
backward
()
else
:
scaler
.
scale
(
avg_loss
).
backward
()
if
not
accumulate_grad
:
if
not
use_pure_fp16
:
optimizer
.
step
()
else
:
scaler
.
step
(
optimizer
)
scaler
.
update
()
optimizer
.
clear_grad
()
if
accumulate_grad
:
if
not
use_pure_fp16
:
optimizer
.
step
()
else
:
scaler
.
step
(
optimizer
)
scaler
.
update
()
optimizer
.
clear_grad
()
if
not
convert2cpu
:
model
.
get_all_parameters
()
else
:
model
.
get_all_parameters
(
convert2cpu
)
return
model
.
parameters
()
def
test_stage3_offload
():
mlp
,
mlp1
,
mlp2
,
mlp3
,
mlp4
,
mlp5
,
mlp6
=
MLP
(),
MLP
(),
MLP
(),
MLP
(),
MLP
(
),
MLP
(),
MLP
()
state_dict
=
mlp
.
state_dict
()
mlp1
.
set_state_dict
(
state_dict
)
mlp2
.
set_state_dict
(
state_dict
)
mlp3
.
set_state_dict
(
state_dict
)
mlp4
.
set_state_dict
(
state_dict
)
mlp5
.
set_state_dict
(
state_dict
)
mlp6
.
set_state_dict
(
state_dict
)
# fp32 offload
stage3_params
=
train_mlp
(
mlp1
,
use_pure_fp16
=
False
)
stage3_params_offload
=
train_mlp
(
mlp2
,
use_pure_fp16
=
False
,
offload
=
True
)
for
i
in
range
(
len
(
stage3_params
)):
np
.
testing
.
assert_allclose
(
stage3_params
[
i
].
numpy
(),
stage3_params_offload
[
i
].
numpy
(),
rtol
=
1e-6
,
atol
=
1e-8
)
# fp16 offload
stage3_params
=
train_mlp
(
mlp3
,
use_pure_fp16
=
True
)
stage3_params_offload
=
train_mlp
(
mlp4
,
use_pure_fp16
=
True
,
offload
=
True
)
for
i
in
range
(
len
(
stage3_params
)):
np
.
testing
.
assert_allclose
(
stage3_params
[
i
].
numpy
(),
stage3_params_offload
[
i
].
numpy
(),
rtol
=
1e-2
,
atol
=
1e-2
)
# fp32 accumulate grad offload
stage3_params
=
train_mlp
(
mlp5
,
use_pure_fp16
=
False
,
accumulate_grad
=
True
)
stage3_params_offload
=
train_mlp
(
mlp6
,
use_pure_fp16
=
False
,
accumulate_grad
=
True
,
offload
=
True
,
convert2cpu
=
True
)
for
i
in
range
(
len
(
stage3_params
)):
np
.
testing
.
assert_allclose
(
stage3_params
[
i
].
numpy
(),
stage3_params_offload
[
i
].
numpy
(),
rtol
=
1e-6
,
atol
=
1e-8
)
return
if
__name__
==
'__main__'
:
test_stage3_offload
()
python/paddle/fluid/tests/unittests/test_dygraph_sharding_stage2.py
浏览文件 @
46823104
...
...
@@ -23,10 +23,10 @@ from test_parallel_dygraph_dataparallel import TestMultipleGpus
class
TestDygraphShardingStage2
(
TestMultipleGpus
):
# check sharding logic as well as the accuracy with single mode
def
test_dygraph_sharding_
optimizer_
stage2
(
self
):
def
test_dygraph_sharding_stage2
(
self
):
self
.
run_mnist_2gpu
(
'dygraph_sharding_stage2.py'
)
def
test_dygraph_sharding_
optimizer_
stage2_offload
(
self
):
def
test_dygraph_sharding_stage2_offload
(
self
):
self
.
run_mnist_2gpu
(
'dygraph_sharding_stage2_offload.py'
)
...
...
python/paddle/fluid/tests/unittests/test_dygraph_sharding_stage3.py
浏览文件 @
46823104
...
...
@@ -23,9 +23,12 @@ from test_parallel_dygraph_dataparallel import TestMultipleGpus
class
TestDygraphShardingStage3
(
TestMultipleGpus
):
# check sharding logic as well as the accuracy with single mode
def
test_dygraph_sharding_
optimizer_
stage3
(
self
):
def
test_dygraph_sharding_stage3
(
self
):
self
.
run_mnist_2gpu
(
'dygraph_sharding_stage3.py'
)
def
test_dygraph_sharding_stage3_offload
(
self
):
self
.
run_mnist_2gpu
(
'dygraph_sharding_stage3_offload.py'
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录