DeepSpeed, commit e80ae088 (unverified)
Authored Mar 23, 2023 by Olatunji Ruwase; committed via GitHub on Mar 23, 2023.
Empty ZeRO3 partition cache (#3060)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

Parent: 5cdf3593
Showing 5 changed files with 89 additions and 1 deletion (+89, -1).
deepspeed/runtime/engine.py                  +10 -0
deepspeed/runtime/zero/parameter_offload.py   +4 -1
deepspeed/runtime/zero/stage3.py              +3 -0
docs/code-docs/source/zero3.rst              +26 -0
tests/unit/runtime/zero/test_zero.py         +46 -0
deepspeed/runtime/engine.py

@@ -9,6 +9,7 @@ import torch
 import hashlib
 from collections import defaultdict, OrderedDict, deque
 from shutil import copyfile
+import gc

 from torch.nn.modules import Module
 from torch.nn.parameter import Parameter
...
@@ -3546,3 +3547,12 @@ class DeepSpeedEngine(Module):
         self.checkpoint_engine.commit(tag)
         return True
+
+    def empty_partition_cache(self):
+        """
+        Release GPU memory consumed by offloaded model parameters.
+        """
+        if hasattr(self.optimizer, 'empty_partition_cache'):
+            self.optimizer.empty_partition_cache()
+        gc.collect()
+        get_accelerator().empty_cache()
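As a usage note (not part of the commit): a minimal sketch of how the effect of the new engine method could be observed from user code. It assumes a CUDA accelerator and an already-initialized ZeRO-3 engine named ds_engine; the memory probe is the standard PyTorch allocator statistic, not a DeepSpeed API.

    # Sketch only: `ds_engine` is assumed to come from deepspeed.initialize(...)
    # with ZeRO stage 3, as in the docs snippet added by this commit.
    import torch

    before = torch.cuda.memory_allocated()
    ds_engine.empty_partition_cache()  # re-partitions params, then gc + cache flush
    after = torch.cuda.memory_allocated()
    print(f"parameter cache released {before - after} bytes")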
deepspeed/runtime/zero/parameter_offload.py

@@ -259,6 +259,9 @@ class DeepSpeedZeRoOffload(object):
         return self.param_coordinators[training]

+    def empty_partition_cache(self):
+        self.partition_all_parameters()
+
     def _convert_to_zero_parameters(self, ds_config, module, mpu):
         non_zero_params = [p for p in module.parameters() if not is_zero_param(p)]
         if non_zero_params:
...
@@ -321,7 +324,7 @@ class DeepSpeedZeRoOffload(object):
             if param.ds_numel + total_persistent_parameters > model_threshold:
                 continue

-            if param.ds_numel < param_threshold:
+            if param.ds_numel <= param_threshold:
                 params_count += 1
                 param.ds_persist = True
                 persistent_params.append(param)
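The second hunk relaxes the persistence test from a strict to an inclusive comparison, so a parameter whose element count exactly equals stage3_param_persistence_threshold is now marked persistent. A tiny standalone illustration of the boundary (the variable names mirror the diff; the values are made up):

    # Boundary behavior before and after this commit, with illustrative values.
    param_threshold = 10   # e.g. stage3_param_persistence_threshold
    ds_numel = 10          # a parameter with exactly threshold-many elements

    old_persist = ds_numel < param_threshold   # False: boundary param was partitioned
    new_persist = ds_numel <= param_threshold  # True: boundary param stays persistent
    print(old_persist, new_persist)            # prints: False True

The new unit test below appears to rely on this boundary: it sets the threshold equal to hidden_dim, so parameters of exactly hidden_dim elements count as persistent.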
deepspeed/runtime/zero/stage3.py

@@ -2467,6 +2467,9 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         if len(self.persistent_parameters) > 0:
             self.persistent_parameters[0].all_gather(self.persistent_parameters)

+    def empty_partition_cache(self):
+        self.parameter_offload.empty_partition_cache()
+

 def _handle_overflow(cpu_sum, x, i):
     import math
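Taken together, the three code changes form a simple delegation chain: DeepSpeedEngine.empty_partition_cache calls the stage-3 optimizer's method, which calls the offload manager's, which re-partitions everything. A runnable stand-in sketch of that flow (simplified classes, not the real implementations; the real engine method additionally runs gc.collect() and an accelerator cache flush, as shown in the engine.py diff above):

    # Simplified stand-ins showing how the new methods chain together.
    class Offload:                        # stands in for DeepSpeedZeRoOffload
        def partition_all_parameters(self):
            print("re-partitioning all gathered ZeRO-3 parameters")

        def empty_partition_cache(self):
            self.partition_all_parameters()

    class Stage3Optimizer:                # stands in for DeepSpeedZeroOptimizer_Stage3
        def __init__(self):
            self.parameter_offload = Offload()

        def empty_partition_cache(self):
            self.parameter_offload.empty_partition_cache()

    class Engine:                         # stands in for DeepSpeedEngine
        def __init__(self):
            self.optimizer = Stage3Optimizer()

        def empty_partition_cache(self):
            # hasattr guard: non-stage-3 optimizers do not define the method
            if hasattr(self.optimizer, 'empty_partition_cache'):
                self.optimizer.empty_partition_cache()

    Engine().empty_partition_cache()      # -> re-partitioning all gathered ZeRO-3 parameters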
docs/code-docs/source/zero3.rst

@@ -331,3 +331,29 @@ These routines can be used in a training loop as shown in the following snippet.
         [...]
         optimizer.step()
+
+GPU Memory Management
+---------------------
+
+By default, at the end of training with ZeRO stage 3 some parameters could remain unpartitioned and use up some GPU memory. This is done on purpose as an optimization, should you resume training again. If you'd like to clear out the cached parameters that use up GPU memory, you can call the ``empty_partition_cache`` method of a DeepSpeed engine.
+
+.. autofunction:: deepspeed.DeepSpeedEngine.empty_partition_cache
+
+The following code snippet illustrates this functionality.
+
+.. code-block:: python
+
+    with zero.Init():
+        model = MyLargeModel()
+
+    ds_engine, _, _, _ = deepspeed.initialize(model, ...)
+    for batch in ...:
+        loss = ds_engine(batch)
+        ds_engine.backward(loss)
+        ds_engine.step()
+
+    # Free GPU memory consumed by model parameters
+    ds_engine.empty_partition_cache()
tests/unit/runtime/zero/test_zero.py

@@ -1422,3 +1422,49 @@ class TestZeroOffloadOptim(DistributedTest):
         model, _, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config=config_dict)
+
+
+@pytest.mark.parametrize('training', [True, False])
+class TestZeroPartitionCache(DistributedTest):
+    world_size = 1
+
+    def test_training_partition_cache(self, training):
+        hidden_dim = 10
+        config_dict = {
+            "train_batch_size": 2,
+            "fp16": {
+                "enabled": True,
+                "initial_scale_power": 8
+            },
+            "zero_optimization": {
+                "stage": 3,
+                "stage3_param_persistence_threshold": hidden_dim
+            }
+        }
+        if training:
+            config_dict["optimizer"] = {"type": "Adam"}
+
+        with deepspeed.zero.Init(config_dict_or_path=config_dict):
+            model = SimpleModel(hidden_dim, empty_grad=False)
+
+        model, _, _, _ = deepspeed.initialize(model=model, config=config_dict)
+
+        dtype = torch.half
+        data_loader = random_dataloader(model=model,
+                                        total_samples=6,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device,
+                                        dtype=dtype)
+
+        for _, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            if training:
+                model.backward(loss)
+                model.step()
+
+        persist_param_size = sum([p.numel() for p in model.parameters() if p.ds_persist])
+
+        assert persist_param_size >= sum([p.numel() for p in model.parameters()])
+
+        model.empty_partition_cache()
+        assert sum([p.numel() for p in model.parameters()]) == 0
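To run just the new test locally, pytest's keyword filter should select it. The invocation below is an assumption about the repository's usual test workflow, not something stated in the commit:

    # Assumed invocation (not from the commit): select the new test class by keyword.
    import pytest

    pytest.main(["tests/unit/runtime/zero/test_zero.py", "-k", "TestZeroPartitionCache"])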