Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Greenplum
DeepSpeed
提交
7567c76c
D
DeepSpeed
项目概览
Greenplum
/
DeepSpeed
上一次同步 大约 1 年
通知
10
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeed
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
7567c76c
编写于
11月 13, 2021
作者:
O
Olatunji Ruwase
提交者:
GitHub
11月 13, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update offload parameter names (#1536)
Co-authored-by:
N
Jeff Rasley
<
jerasley@microsoft.com
>
上级
9caa74e5
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
18 addition
and
13 deletion
+18
-13
deepspeed/runtime/zero/stage2.py
deepspeed/runtime/zero/stage2.py
+3
-1
deepspeed/runtime/zero/stage3.py
deepspeed/runtime/zero/stage3.py
+5
-2
docs/code-docs/source/memory.rst
docs/code-docs/source/memory.rst
+10
-10
未找到文件。
deepspeed/runtime/zero/stage2.py
浏览文件 @
7567c76c
...
...
@@ -15,6 +15,7 @@ import collections
from
deepspeed.runtime.fp16.loss_scaler
import
LossScaler
,
DynamicLossScaler
from
deepspeed.runtime.utils
import
bwc_tensor_model_parallel_rank
,
get_global_norm
,
see_memory_usage
,
is_model_parallel_parameter
from
deepspeed.runtime.zero.config
import
ZERO_OPTIMIZATION_GRADIENTS
from
deepspeed.runtime.zero.offload_constants
import
OFFLOAD_CPU_DEVICE
,
OFFLOAD_OPTIMIZER
,
OFFLOAD_OPTIMIZER_DEVICE
from
deepspeed.ops.adam
import
DeepSpeedCPUAdam
from
deepspeed.ops.op_builder
import
UtilsBuilder
from
deepspeed.utils
import
logger
...
...
@@ -2242,7 +2243,8 @@ def estimate_zero2_model_states_mem_needs_all_cold(total_params,
"""
def
format_options
(
cpu_offload
):
enabled
=
[]
enabled
.
append
(
f
"cpu_offload=
{
1
if
cpu_offload
else
0
}
"
)
device
=
f
'
{
OFFLOAD_CPU_DEVICE
:
4
}
'
if
cpu_offload
else
"none"
enabled
.
append
(
f
"
{
OFFLOAD_OPTIMIZER
}
=
{
device
}
"
)
return
", "
.
join
(
enabled
)
nodes_str
=
"nodes"
if
num_nodes
>
1
else
"node"
...
...
deepspeed/runtime/zero/stage3.py
浏览文件 @
7567c76c
...
...
@@ -3431,8 +3431,11 @@ def estimate_zero3_model_states_mem_needs_all_cold(total_params,
"""
def
format_options
(
cpu_offload
,
cpu_offload_params
,
zero_init
):
enabled
=
[]
enabled
.
append
(
f
"cpu_offload=
{
1
if
cpu_offload
else
0
}
"
)
enabled
.
append
(
f
"cpu_offload_params=
{
1
if
cpu_offload_params
else
0
}
"
)
padded_cpu_str
=
f
'
{
OFFLOAD_CPU_DEVICE
:
4
}
'
param_device
=
padded_cpu_str
if
cpu_offload_params
else
"none"
enabled
.
append
(
f
"
{
OFFLOAD_PARAM
}
=
{
param_device
}
"
)
optimizer_device
=
padded_cpu_str
if
cpu_offload
else
"none"
enabled
.
append
(
f
"
{
OFFLOAD_OPTIMIZER
}
=
{
optimizer_device
}
"
)
enabled
.
append
(
f
"zero_init=
{
1
if
zero_init
else
0
}
"
)
return
", "
.
join
(
enabled
)
...
...
docs/code-docs/source/memory.rst
浏览文件 @
7567c76c
...
...
@@ -128,19 +128,19 @@ The big question is how big of a model you can fit on the hardware you have? Or
*
ZeRO
-
2
:
-
``
"
cpu_offload"
:
true
``:
2
*
params
-
``
"
offload_optimizer"
:
{
"device"
:
"cpu"
}
``:
2
*
params
Example
:
a
40
GB
GPU
can
fit
~
11
B
param
model
(
regardless
of
how
many
GPUs
are
used
).
Here
the
model
is
loaded
in
``
fp16
``
so
just
the
model
weights
take
about
22
GB
and
the
remaining
18
GB
are
used
by
other
components
.
You
can
barely
fit
a
very
small
batch
size
in
this
scenario
.
-
``
"
cpu_offload"
:
false
``:
4
params
+
16
params
/
(
total
number
of
gpus
)
-
``
"
offload_optimizer"
:
{
"device"
:
"none"
}``:
4
*
params
+
16
*
params
/
(
total
number
of
gpus
)
*
ZeRO
-
3
:
``
largest_layer_memory
=
4
*
largest_layer_params
``
-
GPU
memory
needed
to
gather
the
largest
layer
on
a
single
GPU
.
2
bytes
fp16
params
are
gathered
and
2
bytes
fp16
grads
are
computed
(
total
4
x
).
The
optimizer
states
and
fp32
parameters
are
updated
in
partitioned
form
and
copied
to
fp16
params
in
partitioned
form
.
This
happens
during
the
optimizer
step
.
After
that
the
fp16
params
are
sufficient
.
-
case
1
:
``
"
cpu_offload"
:
false
,
"cpu_offload_params"
:
false
``
-
largest_layer_memory
+
18
*
params
/
total
number
of
gpus
across
all
nodes
-
case
2
:
``
"
cpu_offload"
:
true
,
"cpu_offload_params"
:
true
``-
largest_layer_memory
.
The
main
limit
here
is
general
RAM
.
-
case
3
:
``
"
cpu_offload"
:
true
,
"cpu_offload_params"
:
false
``-
largest_layer_memory
+
2
*
params
/
total
number
of
gpus
across
all
nodes
-
case
1
:
``
"
offload_param"
:
{
"device"
:
"none"
},
"offload_optimizer"
:
{
"device"
:
"none"
}
``
-
largest_layer_memory
+
18
*
params
/
total
number
of
gpus
across
all
nodes
-
case
2
:
``
"
offload_param"
:
{
"device"
:
"cpu"
},
"offload_optimizer"
:
{
"device"
:
"cpu"
}
``-
largest_layer_memory
.
The
main
limit
here
is
general
RAM
.
-
case
3
:
``
"
offload_param"
:
{
"device"
:
"none"
},
"offload_optimizer"
:
{
"device"
:
"cpu"
}
``-
largest_layer_memory
+
2
*
params
/
total
number
of
gpus
across
all
nodes
Example
:
...
...
@@ -194,11 +194,11 @@ In the following calculations we will use:
* ZeRO-2:
- ``"
cpu_offload": false
``:
- ``"
offload_optimizer": {"device": "none"}
``:
params * 4 * n_gpus * additional_buffer_factor - this is the memory needed only at the beginning to initialize the model on CPU memory
- ``"
cpu_offload": true
``:
- ``"
offload_optimizer": {"device": "cpu"}
``:
params * max(4 * n_gpus, 16) * additional_buffer_factor
...
...
@@ -208,7 +208,7 @@ In the following calculations we will use:
gpus_factor = n_gpus / total_gpus
- case 1: ``"
cpu_offload": false
``:
- case 1: ``"
offload_param": {"device": "none"}, "offload_optimizer": {"device": "none"}
``:
Without ``zero.Init``:
...
...
@@ -222,7 +222,7 @@ In the following calculations we will use:
assuming Pytorch is deallocating the memory once the tensors are moved to the GPU by ZeRO.Init
- case 2: ``"
cpu_offload": true, cpu_offload_params true
``:
- case 2: ``"
offload_param": {"device": "cpu"}, "offload_optimizer": {"device": "cpu"}
``:
Without ``zero.Init``:
...
...
@@ -232,7 +232,7 @@ In the following calculations we will use:
params * 18 * gpus_factor * additional_buffer_factor
- case 3: ``"
cpu_offload": true, cpu_offload_params false
``:
- case 3: ``"
offload_param": {"device": "none"}, "offload_optimizer": {"device": "cpu"}
``:
Without ``zero.Init``:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录