Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Greenplum
DeepSpeed
提交
0e357666
D
DeepSpeed
项目概览
Greenplum
/
DeepSpeed
上一次同步 大约 1 年
通知
10
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeed
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
0e357666
编写于
4月 26, 2023
作者:
郭叶军
提交者:
GitHub
4月 26, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
stage_1_and_2.py: do gradient scale only for fp16 (#3166)
Co-authored-by:
N
Olatunji Ruwase
<
olruwase@microsoft.com
>
上级
2e99f6ed
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
11 addition
and
6 deletion
+11
-6
deepspeed/runtime/zero/stage_1_and_2.py
deepspeed/runtime/zero/stage_1_and_2.py
+11
-6
未找到文件。
deepspeed/runtime/zero/stage_1_and_2.py
浏览文件 @
0e357666
...
...
@@ -1665,18 +1665,21 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
return
# Step 1:- Calculate gradient norm using fp-16 grads
see_memory_usage
(
'Before norm calculation'
)
scaled_global_grad_norm
=
self
.
scaled_global_norm
()
self
.
_global_grad_norm
=
scaled_global_grad_norm
/
prev_scale
if
self
.
dtype
==
torch
.
float16
:
see_memory_usage
(
'Before norm calculation'
)
scaled_global_grad_norm
=
self
.
scaled_global_norm
()
self
.
_global_grad_norm
=
scaled_global_grad_norm
/
prev_scale
see_memory_usage
(
'After norm before optimizer'
)
see_memory_usage
(
'After norm before optimizer'
)
# Step 2:- run optimizer and upscaling simultaneously
for
i
,
group
in
enumerate
(
self
.
bit16_groups
):
self
.
start_timers
([
OPTIMIZER_GRADIENTS
])
partition_id
=
dist
.
get_rank
(
group
=
self
.
real_dp_process_group
[
i
])
if
self
.
cpu_offload
:
single_grad_partition
=
self
.
single_partition_of_fp32_groups
[
i
].
grad
self
.
unscale_and_clip_grads
([
single_grad_partition
],
scaled_global_grad_norm
)
if
self
.
dtype
==
torch
.
float16
:
self
.
unscale_and_clip_grads
([
single_grad_partition
],
scaled_global_grad_norm
)
self
.
stop_timers
([
OPTIMIZER_GRADIENTS
])
self
.
start_timers
([
OPTIMIZER_STEP
])
self
.
_optimizer_step
(
i
)
...
...
@@ -1715,7 +1718,9 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
self
.
averaged_gradients
[
i
]
=
None
self
.
unscale_and_clip_grads
([
single_grad_partition
],
scaled_global_grad_norm
)
if
self
.
dtype
==
torch
.
float16
:
self
.
unscale_and_clip_grads
([
single_grad_partition
],
scaled_global_grad_norm
)
self
.
stop_timers
([
OPTIMIZER_GRADIENTS
])
# Step 3:- run the optimizer if no offloading
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录