Greenplum / DeepSpeed

Commit f6903190 (unverified)
Authored on Aug 24, 2023 by Joe Mayer; committed by GitHub on Aug 24, 2023
Simplify Gradient Attribute Names (#4214)
* name changes
* formatting changes
Parent: 9647ea79
Showing 1 changed file with 35 additions and 32 deletions (+35 −32)
deepspeed/runtime/zero/stage_1_and_2.py  +35 −32
@@ -286,9 +286,9 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         else:
             self.use_separate_grad_accum = False
         if self.use_separate_grad_accum and not self.partition_gradients:
-            self.use_grad_accum_for_reduction = True
+            self.use_grad_accum_attribute = True
         else:
-            self.use_grad_accum_for_reduction = False
+            self.use_grad_accum_attribute = False

         self.round_robin_bit16_groups = []
         self.round_robin_bit16_indices = []
@@ -828,7 +828,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
     def overlapping_partition_gradients_reduce_epilogue(self):
         self.independent_gradient_partition_epilogue()

-    def update_separate_grad_accum(self):
+    def fill_grad_accum_attribute(self):
         for group in self.bit16_groups:
             for param in group:
                 if param.grad is not None:
@@ -839,20 +839,18 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
                         param.grad.to(self.gradient_accumulation_dtype).view(param.grad_accum.shape))
                     param.grad = None

-    def set_grad_accum_pointer(self):
-        for group in self.bit16_groups:
-            for param in group:
-                param.grad_accum = param.grad
-
     def get_gradient_for_reduction(self, param):
-        if self.use_grad_accum_for_reduction:
+        if self.use_grad_accum_attribute:
             return param.grad_accum.to(self.dtype) if param.grad_accum is not None else None
         else:
             return param.grad

+    def get_param_gradient_attribute(self, param):
+        return param.grad_accum if self.use_grad_accum_attribute else param.grad
+
     # Clear the tensor the reduction gradient attribute is pointing to
-    def clear_grad_reduc_pointer(self, param):
-        if self.use_grad_accum_for_reduction:
+    def clear_grad_attribute(self, param):
+        if self.use_grad_accum_attribute:
             param.grad_accum = None
         else:
             param.grad = None
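The new accessors centralize where a parameter's gradient lives. As a quick illustration, here is a minimal, editor-added sketch (the stub classes are hypothetical and not DeepSpeed's API, though the helper bodies mirror the hunk above) of how use_grad_accum_attribute selects between param.grad_accum and param.grad:

# Editor's illustration (stubbed classes, not DeepSpeed code) of the accessor pattern above.
import torch

class ParamStub:
    def __init__(self):
        self.grad = torch.ones(4)
        self.grad_accum = None

class OptimizerStub:
    def __init__(self, use_grad_accum_attribute):
        self.use_grad_accum_attribute = use_grad_accum_attribute

    # Same selection logic as get_param_gradient_attribute in the hunk above.
    def get_param_gradient_attribute(self, param):
        return param.grad_accum if self.use_grad_accum_attribute else param.grad

    # Same clearing logic as clear_grad_attribute in the hunk above.
    def clear_grad_attribute(self, param):
        if self.use_grad_accum_attribute:
            param.grad_accum = None
        else:
            param.grad = None

param = ParamStub()
opt = OptimizerStub(use_grad_accum_attribute=False)
print(opt.get_param_gradient_attribute(param))  # returns param.grad, since the flag is off
opt.clear_grad_attribute(param)
print(param.grad)  # None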
@@ -1086,7 +1084,8 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
             current_offset += num_elements

     def update_overflow_tracker_for_param_grad(self, param):
-        if param.grad_accum is not None and self._has_inf_or_nan(param.grad_accum.data):
+        grad_accum = self.get_param_gradient_attribute(param)
+        if grad_accum is not None and self._has_inf_or_nan(grad_accum.data):
             self.local_overflow = True

     def _get_offload_gradient_dict(self):
@@ -1117,22 +1116,24 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         #accumulate gradients into param.grad_accum or parts of it that belongs to this partition
         def accumulate_gradients():
+            grad_accum = self.get_param_gradient_attribute(param)
             if not self.fp16_master_weights_and_gradients:
                 dest_buffer.copy_(self.accumulated_grads_in_cpu[param_id].view(-1), non_blocking=True)
-                param.grad_accum.data.view(-1).add_(dest_buffer)
+                grad_accum.data.view(-1).add_(dest_buffer)
             else:
                 dest_buffer.narrow(0, source_offset, num_elements).copy_(
                     self.accumulated_grads_in_cpu[param_id].view(-1), non_blocking=True)
-                param.grad_accum.data.view(-1).narrow(0, source_offset,
-                                                      num_elements).add_(dest_buffer.narrow(0, source_offset, num_elements))
+                grad_accum.data.view(-1).narrow(0, source_offset,
+                                                num_elements).add_(dest_buffer.narrow(0, source_offset, num_elements))

         #move accumulated gradients back to CPU
         def copy_gradients_to_cpu():
+            grad_accum = self.get_param_gradient_attribute(param)
             if not self.fp16_master_weights_and_gradients:
-                self.accumulated_grads_in_cpu[param_id].data.copy_(param.grad_accum.data.view(-1), non_blocking=True)
+                self.accumulated_grads_in_cpu[param_id].data.copy_(grad_accum.data.view(-1), non_blocking=True)
             else:
-                self.accumulated_grads_in_cpu[param_id].data.copy_(param.grad_accum.data.view(-1).narrow(
+                self.accumulated_grads_in_cpu[param_id].data.copy_(grad_accum.data.view(-1).narrow(
                     0, source_offset, num_elements), non_blocking=True)
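For readers unfamiliar with the slicing pattern in accumulate_gradients above, the following toy snippet (editor-added; the sizes and names are invented for the example) shows how one partition of a flattened gradient is updated in place with view, narrow, and add_:

# Toy illustration of view(-1).narrow(0, offset, n).add_(...): update only one slice
# of a flattened tensor in place. Sizes and names here are made up for the example.
import torch

grad_accum = torch.zeros(8)            # stands in for grad_accum.data
cpu_buffer = torch.arange(3.0)         # stands in for accumulated_grads_in_cpu[param_id]
source_offset, num_elements = 2, 3

grad_accum.view(-1).narrow(0, source_offset, num_elements).add_(cpu_buffer)
print(grad_accum)  # tensor([0., 0., 0., 1., 2., 0., 0., 0.])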
@@ -1148,8 +1149,9 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
     def set_norm_for_param_grad(self, param):
         param_id = self.get_param_id(param)
+        grad_accum = self.get_param_gradient_attribute(param)
         accumulated_grad = self.accumulated_grads_in_cpu[
-            param_id] if self.gradient_accumulation_steps > 1 else param.grad_accum
+            param_id] if self.gradient_accumulation_steps > 1 else grad_accum

         [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
@@ -1160,10 +1162,11 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
     def set_norm_for_param_grad_in_gpu(self, param):
         param_id = self.get_param_id(param)
-        if param.grad_accum is None:
+        grad_accum = self.get_param_gradient_attribute(param)
+        if grad_accum is None:
             accumulated_grad = param.grad
         else:
-            accumulated_grad = param.grad_accum
+            accumulated_grad = grad_accum

         [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
@@ -1179,10 +1182,11 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(0, dest_offset, num_elements)

-        if param.grad_accum is None:
-            src_tensor = param.grad.view(-1).narrow(0, source_offset, num_elements)
+        grad_accum = self.get_param_gradient_attribute(param)
+        if grad_accum is None:
+            src_tensor = grad_accum.view(-1).narrow(0, source_offset, num_elements)
         else:
-            src_tensor = param.grad_accum.view(-1).narrow(0, source_offset, num_elements)
+            src_tensor = grad_accum.view(-1).narrow(0, source_offset, num_elements)
         if not self.fp16_master_weights_and_gradients:
             src_tensor = src_tensor.float()
@@ -1314,7 +1318,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
                         self.previous_reduced_grads = []
                     self.previous_reduced_grads.append(param)
                 else:
-                    self.clear_grad_reduc_pointer(param)
+                    self.clear_grad_attribute(param)
             elif self.contiguous_gradients:
                 self.copy_grads_in_partition(param)
             else:  # zero stage 1 - partition only optimizer state
@@ -1425,7 +1429,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
     def _clear_previous_reduced_grads(self):
         if self.previous_reduced_grads is not None:
             for param in self.previous_reduced_grads:
-                self.clear_grad_reduc_pointer(param)
+                self.clear_grad_attribute(param)
             self.previous_reduced_grads = None

     # if rank is specified do a reduction instead of an allreduce
@@ -1605,10 +1609,11 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         current_size = 0
         for i, tensor in enumerate(tensor_list):
-            if tensor.grad_accum is None:
-                tensor.grad_accum = torch.zeros_like(tensor, dtype=dtype)
+            grad_accum = self.get_param_gradient_attribute(tensor)
+            if grad_accum is None:
+                grad_accum = torch.zeros_like(tensor, dtype=dtype)

-            tensor = tensor.grad_accum
+            tensor = grad_accum
             num_elements = tensor.numel()
             tensor_offset = 0
@@ -1953,10 +1958,8 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)

         # Only for Stage 1, Mode 2
-        if self.use_grad_accum_for_reduction:
-            self.update_separate_grad_accum()
-        else:
-            self.set_grad_accum_pointer()
+        if self.use_grad_accum_attribute:
+            self.fill_grad_accum_attribute()

     def check_overflow(self, partition_gradients=True):
         self._check_overflow(partition_gradients)
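To see the simplified backward path in isolation, here is a hedged, self-contained sketch (editor-added; _BackwardStub is not DeepSpeed's class, and the fill_grad_accum_attribute body is a simplified copy of the diff with the view() reshaping omitted) of filling grad_accum only when use_grad_accum_attribute is set:

# Editor's sketch of the post-commit backward() flow: no separate "pointer" bookkeeping,
# and the accumulation attribute is filled only when use_grad_accum_attribute is True.
import torch

class _BackwardStub:
    def __init__(self, params, use_grad_accum_attribute, accum_dtype=torch.float32):
        self.bit16_groups = [params]
        self.use_grad_accum_attribute = use_grad_accum_attribute
        self.gradient_accumulation_dtype = accum_dtype

    def fill_grad_accum_attribute(self):
        # Accumulate .grad into .grad_accum in the accumulation dtype, then drop .grad.
        for group in self.bit16_groups:
            for param in group:
                if param.grad is not None:
                    if getattr(param, "grad_accum", None) is None:
                        param.grad_accum = param.grad.to(self.gradient_accumulation_dtype)
                    else:
                        param.grad_accum.add_(param.grad.to(self.gradient_accumulation_dtype))
                    param.grad = None

    def backward(self, loss):
        loss.backward()
        if self.use_grad_accum_attribute:
            self.fill_grad_accum_attribute()

w = torch.nn.Parameter(torch.ones(3))
stub = _BackwardStub([w], use_grad_accum_attribute=True)
stub.backward((w * 2).sum())
print(w.grad, w.grad_accum)  # .grad cleared to None, .grad_accum holds the accumulated value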