Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Greenplum
DeepSpeed
提交
a52cbf80
D
DeepSpeed
项目概览
Greenplum
/
DeepSpeed
上一次同步 大约 1 年
通知
10
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeed
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
未验证
提交
a52cbf80
编写于
4月 26, 2022
作者:
J
Jeff Rasley
提交者:
GitHub
4月 26, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[zero-3] add bwd support for list/dict types returned in fwd (#1857)
上级
b4fcd98f
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
84 addition
and
2 deletion
+84
-2
deepspeed/runtime/zero/stage3.py
deepspeed/runtime/zero/stage3.py
+20
-2
tests/unit/test_zero.py
tests/unit/test_zero.py
+64
-0
未找到文件。
deepspeed/runtime/zero/stage3.py
浏览文件 @
a52cbf80
...
...
@@ -73,9 +73,14 @@ def move_to_cpu(tensor_list):
tensor
.
data
=
tensor
.
data
.
cpu
()
def
is_builtin_type
(
obj
):
# https://stackoverflow.com/a/17795199
return
obj
.
__class__
.
__module__
==
'__builtin__'
or
obj
.
__class__
.
__module__
==
"builtins"
#apply torch.autograd.Function that calls a backward_function to tensors in output
def
_apply_to_tensors_only
(
module
,
functional
,
backward_function
,
outputs
):
if
type
(
outputs
)
is
tuple
:
if
isinstance
(
outputs
,
(
tuple
,
list
))
:
touched_outputs
=
[]
for
output
in
outputs
:
touched_output
=
_apply_to_tensors_only
(
module
,
...
...
@@ -83,10 +88,23 @@ def _apply_to_tensors_only(module, functional, backward_function, outputs):
backward_function
,
output
)
touched_outputs
.
append
(
touched_output
)
return
tuple
(
touched_outputs
)
return
outputs
.
__class__
(
touched_outputs
)
elif
isinstance
(
outputs
,
dict
):
# apply inplace to avoid recreating dict inherited objects
for
key
in
outputs
.
keys
():
outputs
[
key
]
=
_apply_to_tensors_only
(
module
,
functional
,
backward_function
,
outputs
[
key
])
return
outputs
elif
type
(
outputs
)
is
torch
.
Tensor
:
return
functional
.
apply
(
module
,
backward_function
,
outputs
)
else
:
if
not
is_builtin_type
(
outputs
):
logger
.
warning
(
f
"A module has unknown inputs or outputs type (
{
type
(
outputs
)
}
) and the tensors embedded in it cannot be detected. "
"The ZeRO-3 hooks designed to trigger before or after backward pass of the module relies on knowing the input and "
"output tensors and therefore may not get triggered properly."
)
return
outputs
...
...
tests/unit/test_zero.py
浏览文件 @
a52cbf80
...
...
@@ -1222,3 +1222,67 @@ def test_zero_offload_stage1():
model
.
step
()
_go
(
model
=
model
,
hidden_dim
=
hidden_dim
)
@
pytest
.
mark
.
parametrize
(
'return_type'
,
[
tuple
,
list
,
dict
])
def
test_z3_dict_fwd
(
return_type
):
config_dict
=
{
"train_batch_size"
:
4
,
"steps_per_print"
:
1
,
"optimizer"
:
{
"type"
:
"Adam"
,
"params"
:
{
"lr"
:
1e-4
}
},
"fp16"
:
{
"enabled"
:
True
},
"zero_optimization"
:
{
"stage"
:
3
}
}
hidden_dim
=
10
class
MyModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
hidden_dim
):
super
(
MyModel
,
self
).
__init__
()
self
.
l1
=
torch
.
nn
.
Linear
(
hidden_dim
,
hidden_dim
)
self
.
cel
=
torch
.
nn
.
CrossEntropyLoss
()
def
forward
(
self
,
x
,
y
):
x
=
self
.
l1
(
x
)
loss
=
self
.
cel
(
x
,
y
)
if
return_type
==
dict
:
val
=
{
'a'
:
x
,
'loss'
:
loss
,
'b'
:
1
,
'c'
:
None
}
elif
return_type
==
list
:
val
=
[
x
,
loss
]
elif
return_type
==
tuple
:
val
=
(
x
,
loss
)
else
:
raise
NotImplementedError
return
val
@
distributed_test
(
world_size
=
[
1
])
def
_go
(
hidden_dim
):
with
deepspeed
.
zero
.
Init
():
model
=
MyModel
(
hidden_dim
)
model
,
_
,
_
,
_
=
deepspeed
.
initialize
(
model
=
model
,
model_parameters
=
model
.
parameters
(),
config
=
config_dict
)
data_loader
=
random_dataloader
(
model
=
model
,
total_samples
=
50
,
hidden_dim
=
hidden_dim
,
device
=
model
.
device
)
torch
.
distributed
.
barrier
()
for
n
,
batch
in
enumerate
(
data_loader
):
loss
=
model
(
batch
[
0
],
batch
[
1
])
if
return_type
==
dict
:
loss
=
loss
[
'loss'
]
else
:
loss
=
loss
[
1
]
model
.
backward
(
loss
)
model
.
step
()
_go
(
hidden_dim
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录