Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Greenplum
DeepSpeed
提交
9d79cfd1
D
DeepSpeed
项目概览
Greenplum
/
DeepSpeed
上一次同步 12 个月
通知
10
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeed
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
未验证
提交
9d79cfd1
编写于
8月 14, 2023
作者:
O
Olatunji Ruwase
提交者:
GitHub
8月 14, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Respect memory pinning config (#4131)
* Respect memory pinning config * Bug fix
上级
7a282db8
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
22 addition
and
12 deletion
+22
-12
deepspeed/runtime/engine.py
deepspeed/runtime/engine.py
+1
-1
deepspeed/runtime/zero/stage_1_and_2.py
deepspeed/runtime/zero/stage_1_and_2.py
+21
-11
未找到文件。
deepspeed/runtime/engine.py
浏览文件 @
9d79cfd1
...
...
@@ -1452,7 +1452,7 @@ class DeepSpeedEngine(Module):
expert_data_parallel_group
=
self
.
expert_data_parallel_group
if
self
.
has_moe_layers
else
None
,
reduce_scatter
=
self
.
zero_reduce_scatter
(),
overlap_comm
=
overlap_comm
,
cpu_offload
=
self
.
zero_cpu_offload
(),
offload_optimizer_config
=
self
.
zero_offload_optimizer
(),
mpu
=
self
.
mpu
,
postscale_gradients
=
self
.
postscale_gradients
(),
gradient_predivide_factor
=
self
.
gradient_predivide_factor
(),
...
...
deepspeed/runtime/zero/stage_1_and_2.py
浏览文件 @
9d79cfd1
...
...
@@ -120,7 +120,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
expert_data_parallel_group
=
None
,
reduce_scatter
=
True
,
overlap_comm
=
False
,
cpu_offload
=
Fals
e
,
offload_optimizer_config
=
Non
e
,
mpu
=
None
,
clip_grad
=
0.0
,
gradient_accumulation_dtype
=
torch
.
float32
,
...
...
@@ -135,10 +135,17 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
fp16_master_weights_and_gradients
=
False
,
elastic_checkpoint
=
False
):
if
offload_optimizer_config
is
not
None
and
offload_optimizer_config
.
device
!=
OffloadDeviceEnum
.
none
:
self
.
cpu_offload
=
True
self
.
cpu_offload_pin_memory
=
offload_optimizer_config
.
pin_memory
else
:
self
.
cpu_offload
=
False
self
.
cpu_offload_pin_memory
=
False
if
dist
.
get_rank
()
==
0
:
logger
.
info
(
f
"Reduce bucket size
{
reduce_bucket_size
}
"
)
logger
.
info
(
f
"Allgather bucket size
{
allgather_bucket_size
}
"
)
logger
.
info
(
f
"CPU Offload:
{
cpu_offload
}
"
)
logger
.
info
(
f
"CPU Offload:
{
self
.
cpu_offload
}
"
)
logger
.
info
(
f
'Round robin gradient partitioning:
{
round_robin_gradients
}
'
)
# The fused optimizer does all the work. We need this layer for two reason:
# 1. maintain same user API from apex.fp16_utils
...
...
@@ -153,7 +160,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
# - flat by groups, not keeping state. TODO: remove state explicitly?
# - master grad and unflat master weight never exist. TODO: a way to save out unflat master?
if
not
get_accelerator
().
is_available
():
raise
SystemError
(
"
Cannot use fp16 without accelerator
."
)
raise
SystemError
(
"
Accelerator is not detected, cannot perform low precision training (e.g., fp16, bf16)
."
)
self
.
optimizer
=
init_optimizer
# Use torch (un)flatten ops
...
...
@@ -170,9 +177,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
self
.
overlap_comm
=
overlap_comm
self
.
cpu_offload
=
cpu_offload
self
.
deepspeed_adam_offload
=
cpu_offload
self
.
deepspeed_adam_offload
=
self
.
cpu_offload
self
.
device
=
get_accelerator
().
current_device_name
()
if
not
self
.
cpu_offload
else
'cpu'
...
...
@@ -195,7 +200,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
self
.
is_gradient_accumulation_boundary
=
True
# CPU-Offload requires contiguous gradients
self
.
contiguous_gradients
=
contiguous_gradients
or
cpu_offload
self
.
contiguous_gradients
=
contiguous_gradients
or
self
.
cpu_offload
self
.
has_moe_layers
=
has_moe_layers
if
self
.
has_moe_layers
:
...
...
@@ -440,8 +445,12 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
self
.
norm_for_param_grads
=
{}
self
.
local_overflow
=
False
self
.
grad_position
=
{}
self
.
temp_grad_buffer_for_cpu_offload
=
get_accelerator
().
pin_memory
(
torch
.
zeros
(
largest_param_numel
,
device
=
self
.
device
,
dtype
=
self
.
dtype
))
self
.
temp_grad_buffer_for_cpu_offload
=
torch
.
zeros
(
largest_param_numel
,
device
=
self
.
device
,
dtype
=
self
.
dtype
)
if
self
.
cpu_offload_pin_memory
:
self
.
temp_grad_buffer_for_cpu_offload
=
get_accelerator
().
pin_memory
(
self
.
temp_grad_buffer_for_cpu_offload
)
self
.
temp_grad_buffer_for_gpu_offload
=
torch
.
zeros
(
largest_param_numel
,
device
=
get_accelerator
().
current_device_name
(),
dtype
=
self
.
dtype
)
...
...
@@ -631,7 +640,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
dtype
=
self
.
single_partition_of_fp32_groups
[
i
].
dtype
,
device
=
self
.
device
)
self
.
single_partition_of_fp32_groups
[
i
].
grad
=
get_accelerator
().
pin_memory
(
single_grad_partition
)
if
self
.
cpu_offload
else
single_grad_partition
single_grad_partition
)
if
self
.
cpu_offload
_pin_memory
else
single_grad_partition
# Initialize the optimizer states with the flattened fp32 partition.
# State initialization for the Adagrad optimizer occurs at construction as opposed to other optimizers
...
...
@@ -1101,7 +1110,8 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
#buffer for storing gradients for this parameter in CPU
def
buffer_to_accumulate_to_in_cpu
():
if
not
self
.
fp16_master_weights_and_gradients
:
return
get_accelerator
().
pin_memory
(
torch
.
zeros
(
param
.
numel
(),
dtype
=
param
.
dtype
,
device
=
self
.
device
))
buffer
=
torch
.
zeros
(
param
.
numel
(),
dtype
=
param
.
dtype
,
device
=
self
.
device
)
return
get_accelerator
().
pin_memory
(
buffer
)
if
self
.
cpu_offload_pin_memory
else
buffer
else
:
return
self
.
single_partition_of_fp32_groups
[
i
].
grad
.
view
(
-
1
).
narrow
(
0
,
dest_offset
,
num_elements
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录