Commit adcfd269 (unverified)
Authored Jan 12, 2021 by Shaden Smith; committed via GitHub on Jan 12, 2021
Handle activation checkpointing args that are None or non-tensors (#660)
Special thanks to @g-karthik for tracking this issue down.
Parent: da5563a9
Showing 2 changed files with 49 additions and 5 deletions (+49 −5):

- deepspeed/runtime/activation_checkpointing/checkpointing.py (+22 −2)
- tests/unit/test_activation_checkpointing.py (+27 −3)
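Why it failed: `CheckpointFunction.forward` previously moved every positional argument to the GPU with `item.to(cuda_device)`, which assumes every argument is a tensor, so a `None` mask (or any other non-tensor) raised an `AttributeError`. A minimal sketch of the failure mode and of the fix's pattern, outside DeepSpeed (hypothetical repro, assumes a CUDA device is available):

```python
import torch

cuda_device = torch.device('cuda')
args = (torch.rand(4), None)  # a None mask, as in the new unit test

# Pre-fix behavior: unconditionally calling .to() on each argument.
try:
    inputs_cuda = [item.to(cuda_device) for item in args]
except AttributeError as e:
    print(e)  # 'NoneType' object has no attribute 'to'

# Post-fix behavior: only tensors are moved; other values pass through.
inputs_cuda = [item.to(cuda_device) if torch.is_tensor(item) else item
               for item in args]
```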
deepspeed/runtime/activation_checkpointing/checkpointing.py

```diff
@@ -373,6 +373,10 @@ class CheckpointFunction(torch.autograd.Function):
             inputs = []
             for i, item in enumerate(args[:-1]):
+                if not torch.is_tensor(item):
+                    inputs.append(item)
+                    continue
+
                 partition_size = get_partition_size(item)
                 partition = item.detach().contiguous().view(-1).narrow(0,
@@ -413,7 +417,12 @@ class CheckpointFunction(torch.autograd.Function):
             inputs.append(args[-1])

         #just in case something funky is happening such as reuse of inputs
-        inputs_cuda = [item.to(cuda_device) for item in args]
+        inputs_cuda = []
+        for item in args:
+            if torch.is_tensor(item):
+                inputs_cuda.append(item.to(cuda_device))
+            else:
+                inputs_cuda.append(item)

         # Copy the rng states.
         ctx.fwd_cpu_rng_state = torch.get_rng_state()
@@ -439,6 +448,10 @@ class CheckpointFunction(torch.autograd.Function):
         if PARTITION_ACTIVATIONS:
             new_args = []
             for i, (arg, inp) in enumerate(zip(args, inputs)):
+                if not torch.is_tensor(arg):
+                    new_args.append(arg)
+                    continue
+
                 size = torch.tensor(arg.size())
                 arg.data = inp.data
@@ -573,7 +586,14 @@ class CheckpointFunction(torch.autograd.Function):
             timers.log(['backward'])
         if SYNCHRONIZE:
             torch.cuda.synchronize()
-        return (None, ) + tuple(inp.grad for inp in detached_inputs)
+        ret_list = [None]  # first None for ctx
+        for inp in detached_inputs:
+            if torch.is_tensor(inp):
+                ret_list.append(inp.grad)
+            else:
+                ret_list.append(None)
+
+        return tuple(ret_list)


 def checkpoint(function, *args):
```
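A note on the backward-pass change: `torch.autograd.Function.backward` must return exactly one value per input that `forward` received (excluding `ctx`), in order, using `None` for inputs that have no gradient, including non-tensors. That contract is why the single-expression return is replaced by the `ret_list` loop. An illustrative standalone sketch of the contract (not from this commit):

```python
import torch

class PassThrough(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, flag):
        # `flag` is a non-tensor argument; it still occupies a
        # slot in backward's return tuple.
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward input: a gradient for `x`,
        # None for the non-tensor `flag`.
        return grad_output * 2, None

x = torch.ones(3, requires_grad=True)
PassThrough.apply(x, "not-a-tensor").sum().backward()
print(x.grad)  # tensor([2., 2., 2.])
```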
tests/unit/test_activation_checkpointing.py

```diff
@@ -23,7 +23,7 @@ def _compute(module, *inputs, do_checkpoint=False):
     sum(o.sum() for o in outputs if o.requires_grad).backward()

     grads = [p.grad for p in module.parameters()]
-    input_grads = [inp.grad for inp in inputs]
+    input_grads = [inp.grad for inp in inputs if torch.is_tensor(inp)]

     return {
         'outputs': outputs,
@@ -32,6 +32,18 @@ def _compute(module, *inputs, do_checkpoint=False):
     }


+def _prep_inputs(*inputs):
+    _inputs = []
+
+    for inp in inputs:
+        inp = deepcopy(inp)
+        if torch.is_tensor(inp):
+            inp = inp.cuda()
+        _inputs.append(inp)
+
+    return tuple(_inputs)
+
+
 # This is distributed because checkpoint() assumes that torch.distributed is initialized.
 # torch.distributed is used with activation partitioning, but not for these simple cases.
 @distributed_test(world_size=1)
@@ -43,11 +55,11 @@ def _test_activation_checkpoint(module, *inputs):
     module.eval()

     module_ = deepcopy(module)
-    inputs_ = tuple(deepcopy(inp).cuda() for inp in inputs)
+    inputs_ = _prep_inputs(*inputs)
     base = _compute(module_, *inputs_, do_checkpoint=False)

     module_ = deepcopy(module)
-    inputs_ = tuple(deepcopy(inp).cuda() for inp in inputs)
+    inputs_ = _prep_inputs(*inputs)
     test = _compute(module_, *inputs_, do_checkpoint=True)

     for group in base.keys():
@@ -155,3 +167,15 @@ def test_ckpt_inputs2_outputs3(mask):
     inputs = torch.rand(HIDDEN_DIM)
     inputs.requires_grad = True
     _test_activation_checkpoint(module, inputs, mask)
+
+
+class DropMaskLinear(torch.nn.Linear):
+    def forward(self, x, mask):
+        return super().forward(x)
+
+
+def test_ckpt_arg_none():
+    module = DropMaskLinear(HIDDEN_DIM, HIDDEN_DIM)
+    inputs = (torch.rand(HIDDEN_DIM), None)
+    inputs[0].requires_grad = True
+    _test_activation_checkpoint(module, *inputs)
```
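The new `test_ckpt_arg_none` mirrors a common real-world pattern: a layer that accepts an optional mask and may receive `None`. For comparison, stock PyTorch's `torch.utils.checkpoint.checkpoint` tolerates non-tensor arguments similarly, returning `None` gradients for them. A quick sanity sketch against plain PyTorch (no DeepSpeed or distributed setup required; the `use_reentrant` argument exists only on recent PyTorch versions, so omit it on older ones):

```python
import torch
from torch.utils.checkpoint import checkpoint

class DropMaskLinear(torch.nn.Linear):
    def forward(self, x, mask):
        return super().forward(x)  # mask accepted but unused

module = DropMaskLinear(4, 4)
x = torch.rand(4, requires_grad=True)

# The None mask flows through checkpointing untouched; only the
# tensor input receives a gradient.
out = checkpoint(module, x, None, use_reentrant=False)
out.sum().backward()
print(x.grad.shape)  # torch.Size([4])
```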