Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Greenplum
DeepSpeed
提交
ba67bd9a
D
DeepSpeed
项目概览
Greenplum
/
DeepSpeed
上一次同步 大约 1 年
通知
10
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeed
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
ba67bd9a
编写于
7月 29, 2022
作者:
N
Nicholas Cilfone
提交者:
GitHub
7月 29, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Added retain_graph as a kwarg to the main engine backward function (#1149)
Co-authored-by:
N
Jeff Rasley
<
jerasley@microsoft.com
>
上级
556f0051
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
9 addition
and
6 deletion
+9
-6
deepspeed/runtime/engine.py
deepspeed/runtime/engine.py
+9
-6
未找到文件。
deepspeed/runtime/engine.py
浏览文件 @
ba67bd9a
...
...
@@ -1711,11 +1711,14 @@ class DeepSpeedEngine(Module):
loss
,
allreduce_gradients
=
True
,
release_loss
=
False
,
retain_graph
=
False
,
scale_wrt_gas
=
True
):
r
"""Execute backward pass on the loss
Arguments:
loss: Torch tensor on which to execute backward propagation
allreduce_gradients: is deprecated, ignored, and will soon be removed'
retain_graph: bool, default: false
forward on user defined choice of retain_graph
"""
see_memory_usage
(
"Engine before backward"
,
force
=
self
.
memory_breakdown
())
...
...
@@ -1751,9 +1754,9 @@ class DeepSpeedEngine(Module):
self
.
_start_timers
(
self
.
engine_timers
.
backward_inner_timers
)
if
self
.
zero_optimization
():
self
.
optimizer
.
is_gradient_accumulation_boundary
=
(
self
.
is_gradient_accumulation_boundary
()
)
self
.
optimizer
.
backward
(
loss
)
self
.
optimizer
.
is_gradient_accumulation_boundary
=
self
.
is_gradient_accumulation_boundary
(
)
self
.
optimizer
.
backward
(
loss
,
retain_graph
=
retain_graph
)
elif
self
.
amp_enabled
():
# AMP requires delaying unscale when inside gradient accumulation boundaries
# https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
...
...
@@ -1761,19 +1764,19 @@ class DeepSpeedEngine(Module):
with
amp
.
scale_loss
(
loss
,
self
.
optimizer
,
delay_unscale
=
delay_unscale
)
as
scaled_loss
:
scaled_loss
.
backward
()
scaled_loss
.
backward
(
retain_graph
=
retain_graph
)
elif
self
.
fp16_enabled
():
if
self
.
eigenvalue_enabled
():
self
.
optimizer
.
backward
(
loss
,
create_graph
=
True
,
retain_graph
=
True
)
else
:
self
.
optimizer
.
backward
(
loss
)
self
.
optimizer
.
backward
(
loss
,
retain_graph
=
retain_graph
)
elif
self
.
bfloat16_enabled
():
self
.
optimizer
.
backward
(
loss
)
else
:
if
self
.
eigenvalue_enabled
():
loss
.
backward
(
create_graph
=
True
,
retain_graph
=
True
)
else
:
loss
.
backward
()
loss
.
backward
(
retain_graph
=
retain_graph
)
self
.
_stop_timers
(
self
.
engine_timers
.
backward_inner_timers
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录