Unverified commit 376818ef
Authored on Jul 15, 2020 by Jeff Rasley; committed via GitHub on Jul 15, 2020
Empty grad fix (#291)
* empty grad fix
* add unit tests for empty grad
Parent: 607814fe
Showing 3 changed files with 101 additions and 9 deletions.
deepspeed/pt/deepspeed_light.py (+11, -1)
tests/unit/simple_model.py (+8, -3)
tests/unit/test_fp16.py (+82, -5)
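The core of the fix: when the number of empty (None) gradients differs across ranks, every rank must still contribute the same number of elements to the allreduce, so missing gradients are padded with zeros. A minimal sketch of the idea in plain PyTorch (not the DeepSpeed engine code; grads_for_allreduce is a hypothetical helper name):

import torch

def grads_for_allreduce(module):
    # Every rank must reduce the same total size, so a parameter whose
    # .grad is None on this rank is replaced by an explicit zero tensor.
    # Averaging by world size then remains correct.
    grads = []
    for _, param in module.named_parameters():
        if param.grad is None:
            grads.append(torch.zeros(param.size(),
                                     dtype=param.dtype,
                                     device=param.device))
        else:
            grads.append(param.grad.data)
    return grads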
deepspeed/pt/deepspeed_light.py
@@ -979,7 +979,17 @@ class DeepSpeedLight(Module):
     def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000):
         grads = []
         for param_name, param in self.module.named_parameters():
-            if param.grad is not None:
+            if param.grad is None:
+                # In cases where there is an imbalance of empty grads across
+                # ranks we must create empty grads, this will ensure that every
+                # rank is reducing the same size. In some cases it may make
+                # sense in the future to support the ability to average not
+                # w.r.t. world size but with a different value.
+                grads.append(
+                    torch.zeros(param.size(),
+                                dtype=param.dtype,
+                                device=param.device))
+            else:
                 grad_data = param.grad.data
                 if self.sparse_gradients_enabled(
                 ) and param_name in self.csr_tensor_module_names:
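For context, a parameter's .grad stays None whenever it never participates in the backward pass. A quick single-process check of that PyTorch behavior (illustrative only, not part of the commit):

import torch

m = torch.nn.ModuleDict({
    "used": torch.nn.Linear(4, 4),
    "unused": torch.nn.Linear(4, 4),
})
m["used"](torch.randn(2, 4)).sum().backward()

# Only the layer that took part in forward/backward has gradients; the
# other still reports None, which is what the new branch above pads out.
assert all(p.grad is not None for p in m["used"].parameters())
assert all(p.grad is None for p in m["unused"].parameters())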
tests/unit/simple_model.py
@@ -5,16 +5,21 @@ import torch
 class SimpleModel(torch.nn.Module):
-    def __init__(self, hidden_dim, empty_grad=False):
+    def __init__(self, hidden_dim, empty_grad=False, rank=0):
         super(SimpleModel, self).__init__()
         self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
         if empty_grad:
-            self.layers2 = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim)])
+            self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
         self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
+        self.rank = rank
+        self.empty_grad = empty_grad

     def forward(self, x, y):
         hidden_dim = x
-        hidden_dim = self.linear(hidden_dim)
+        if self.rank == 0 and self.empty_grad:
+            hidden_dim = self.linear(hidden_dim) + self.linear2(hidden_dim)
+        else:
+            hidden_dim = self.linear(hidden_dim)
         return self.cross_entropy_loss(hidden_dim, y)
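The new rank argument is what makes the gradient imbalance reproducible: with world_size=2, only rank 0 routes activations through linear2, so after backward() the other rank still holds None gradients for it. A single-process illustration (assuming tests/unit is on the import path; two model instances stand in for two ranks):

import torch
from simple_model import SimpleModel  # tests/unit/simple_model.py

x = torch.randn(4, 10)
y = torch.randint(0, 10, (4,))

rank0 = SimpleModel(hidden_dim=10, empty_grad=True, rank=0)
rank1 = SimpleModel(hidden_dim=10, empty_grad=True, rank=1)
rank0(x, y).backward()  # forward uses linear2, so its grads are filled
rank1(x, y).backward()  # forward skips linear2, so its grads stay None

assert all(p.grad is not None for p in rank0.linear2.parameters())
assert all(p.grad is None for p in rank1.linear2.parameters())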
tests/unit/test_fp16.py
@@ -33,9 +33,10 @@ def test_lamb_fp32_grad_clip(tmpdir):
         data_loader = random_dataloader(model=model,
                                         total_samples=50,
                                         hidden_dim=hidden_dim,
-                                        device=model.device)
+                                        device=model.device,
+                                        dtype=torch.float)
         for n, batch in enumerate(data_loader):
-            loss = model(batch[0].float(), batch[1])
+            loss = model(batch[0], batch[1])
             model.backward(loss)
             model.step()
@@ -81,7 +82,7 @@ def test_lamb_fp16_basic(tmpdir):
 def test_lamb_fp16_empty_grad(tmpdir):
     config_dict = {
-        "train_batch_size": 1,
+        "train_batch_size": 2,
         "steps_per_print": 1,
         "optimizer": {
             "type": "Lamb",
@@ -97,9 +98,9 @@ def test_lamb_fp16_empty_grad(tmpdir):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10

-    model = SimpleModel(hidden_dim, empty_grad=True)
+    model = SimpleModel(hidden_dim, empty_grad=True, rank=args.local_rank)

-    @distributed_test(world_size=[1])
+    @distributed_test(world_size=[2])
     def _test_lamb_fp16_empty_grad(args, model, hidden_dim):
         model, _, _, _ = deepspeed.initialize(args=args,
                                               model=model,
@@ -116,6 +117,44 @@ def test_lamb_fp16_empty_grad(tmpdir):
     _test_lamb_fp16_empty_grad(args=args, model=model, hidden_dim=hidden_dim)

+
+def test_adam_fp32_empty_grad(tmpdir):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "gradient_clipping": 1.0,
+        "fp16": {
+            "enabled": False
+        }
+    }
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+
+    model = SimpleModel(hidden_dim, empty_grad=True, rank=args.local_rank)
+
+    @distributed_test(world_size=[2])
+    def _test_adam_fp32_empty_grad(args, model, hidden_dim):
+        model, _, _, _ = deepspeed.initialize(args=args,
+                                              model=model,
+                                              model_parameters=model.parameters())
+        data_loader = random_dataloader(model=model,
+                                        total_samples=50,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device,
+                                        dtype=torch.float)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+    _test_adam_fp32_empty_grad(args=args, model=model, hidden_dim=hidden_dim)
+
 def test_adamw_fp16_basic(tmpdir):
     config_dict = {
         "train_batch_size": 1,
@@ -495,3 +534,41 @@ def test_adam_amp_o2(tmpdir):
             model.step()

     _test_adam_amp_o2(args=args, model=model, hidden_dim=hidden_dim)
+
+
+def test_adam_amp_o2_empty_grad(tmpdir):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "gradient_clipping": 1.0,
+        "amp": {
+            "enabled": True,
+            "opt_level": "O2"
+        }
+    }
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+
+    model = SimpleModel(hidden_dim, empty_grad=False, rank=args.local_rank)
+
+    @distributed_test(world_size=[2])
+    def _test_adam_amp_o2_empty_grad(args, model, hidden_dim):
+        model, _, _, _ = deepspeed.initialize(args=args,
+                                              model=model,
+                                              model_parameters=model.parameters())
+        data_loader = random_dataloader(model=model,
+                                        total_samples=50,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+    _test_adam_amp_o2_empty_grad(args=args, model=model, hidden_dim=hidden_dim)
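Note that the new cases run at world_size=[2], since the bug only manifests when ranks disagree about which gradients are empty. They belong to the repo's pytest-based unit suite, so presumably something like `pytest tests/unit/test_fp16.py -k empty_grad` exercises them, with the @distributed_test decorator spinning up the two ranks.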