Unverified commit 607814fe
Authored July 15, 2020 by Olatunji Ruwase; committed by GitHub on July 15, 2020
Fix bug in fp32 optimizer state loading (#289)
Parent: 7ccc9daf

Showing 3 changed files with 54 additions and 13 deletions (+54, -13):

deepspeed/pt/deepspeed_light.py (+6, -2)
tests/unit/simple_model.py (+2, -2)
tests/unit/test_checkpointing.py (+46, -9)
deepspeed/pt/deepspeed_light.py
@@ -1140,8 +1140,12 @@ class DeepSpeedLight(Module):
         self.load_module_state_dict(state_dict=checkpoint['module'],
                                     strict=load_module_strict)

         if not self.zero_optimization():
-            self.optimizer.load_state_dict(checkpoint['optimizer'],
-                                           load_optimizer_states=load_optimizer_states)
+            if self.fp16_enabled():
+                self.optimizer.load_state_dict(
+                    checkpoint['optimizer'],
+                    load_optimizer_states=load_optimizer_states)
+            else:
+                self.optimizer.load_state_dict(checkpoint['optimizer'])

         if load_lr_scheduler_states and self.lr_scheduler is not None:
             self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
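What the hunk fixes: with fp16 enabled, self.optimizer is DeepSpeed's fp16 optimizer wrapper, whose load_state_dict accepts a load_optimizer_states keyword; with fp16 disabled it is a plain torch.optim.Optimizer, whose load_state_dict(state_dict) takes no such keyword, so the old unconditional call failed. A minimal sketch of the failure mode, assuming only that the wrapper accepts the extra keyword (the wrapper class below is illustrative, not DeepSpeed's real one):

import torch

class Fp16OptimizerSketch:
    # Illustrative stand-in for an fp16 wrapper whose load_state_dict
    # accepts the extra keyword (not DeepSpeed's actual class).
    def __init__(self, inner):
        self.inner = inner

    def load_state_dict(self, state_dict, load_optimizer_states=True):
        if load_optimizer_states:
            self.inner.load_state_dict(state_dict)

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters())
saved = optimizer.state_dict()

# The fp32 path: a vanilla torch optimizer rejects the extra keyword,
# which is why the hunk above adds the fp16_enabled() branch.
try:
    optimizer.load_state_dict(saved, load_optimizer_states=True)
except TypeError as err:
    print(f"fp32 optimizer: {err}")

optimizer.load_state_dict(saved)                       # fp32 branch: positional only
Fp16OptimizerSketch(optimizer).load_state_dict(saved)  # fp16 branch: wrapper takes the kwarg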
tests/unit/simple_model.py
@@ -41,9 +41,9 @@ class SimpleOptimizer(torch.optim.Optimizer):
         return loss


-def random_dataloader(model, total_samples, hidden_dim, device):
+def random_dataloader(model, total_samples, hidden_dim, device, dtype=torch.half):
     batch_size = model.train_micro_batch_size_per_gpu()
-    train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=torch.half)
+    train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype)
     train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim)
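The change threads a dtype through instead of hard-coding torch.half, keeping fp16 as the default so existing callers are untouched while the new fp32 test can request inputs that match its weights. A self-contained sketch of the pattern (the helper name here is mine, not the repo's):

import torch

def make_train_data(total_samples, hidden_dim, device, dtype=torch.half):
    # Generate in fp32 and cast, so this sketch also runs on CPU builds
    # where half-precision RNG kernels may be unavailable.
    return torch.randn(total_samples, hidden_dim, device=device).to(dtype)

half_data = make_train_data(50, 10, device="cpu")                       # old default: fp16
fp32_data = make_train_data(50, 10, device="cpu", dtype=torch.float32)  # new fp32 path
assert half_data.dtype == torch.half and fp32_data.dtype == torch.float32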
tests/unit/test_checkpointing.py
@@ -47,14 +47,18 @@ def compare_model_states(saved_model, loaded_model):
         for params0, params1 in zip(saved_model.optimizer.fp32_groups,
                                     loaded_model.optimizer.fp32_groups):
             for p0, p1 in zip(params0, params1):
                 assert torch.allclose(p0, p1, atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"
     elif isinstance(saved_model.optimizer, torch.optim.Optimizer):
         pass
     else:
-        assert False, 'Unexpected Optimizer Type'
+        assert False, f'Unexpected Optimizer Type: {saved_model.optimizer}'


-def compare_optimizer_states(saved_model, loaded_model, hidden_dim):
-    for state0, state1 in zip(saved_model.optimizer.optimizer.state.values(),
-                              loaded_model.optimizer.optimizer.state.values()):
+def compare_optimizer_states(saved_model, loaded_model, hidden_dim, fp16=True):
+    saved_optimizer = saved_model.optimizer.optimizer if fp16 else saved_model.optimizer
+    loaded_optimizer = loaded_model.optimizer.optimizer if fp16 else loaded_model.optimizer
+
+    for state0, state1 in zip(saved_optimizer.state.values(),
+                              loaded_optimizer.state.values()):
         for s0, s1 in zip(state0.values(), state1.values()):
             if isinstance(s0, torch.Tensor) and isinstance(s1, torch.Tensor):
                 assert torch.equal(s0, s1)
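The fp16 flag controls one unwrapping step: with fp16 enabled, the engine's optimizer attribute is the wrapper and the real torch.optim instance sits at .optimizer.optimizer; with fp16 disabled, .optimizer already is the base optimizer, so reaching one level deeper would fail. A small self-contained check of the tensor-state comparison the test performs (the helper name is illustrative):

import torch

def get_base_optimizer(engine_optimizer, fp16):
    # Mirrors the added lines: unwrap only when an fp16 wrapper is present.
    return engine_optimizer.optimizer if fp16 else engine_optimizer

# Round-trip an Adam state dict and verify tensor states compare equal,
# as compare_optimizer_states does.
model = torch.nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.randn(3, 4)).sum().backward()
opt.step()

reloaded = torch.optim.Adam(model.parameters(), lr=1e-3)
reloaded.load_state_dict(opt.state_dict())

for state0, state1 in zip(opt.state.values(), reloaded.state.values()):
    for s0, s1 in zip(state0.values(), state1.values()):
        if isinstance(s0, torch.Tensor) and isinstance(s1, torch.Tensor):
            assert torch.equal(s0, s1)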
@@ -90,15 +94,17 @@ def checkpoint_correctness_verification(args,
                                         hidden_dim,
                                         tmpdir,
                                         load_optimizer_states=False,
-                                        load_lr_scheduler_states=False):
+                                        load_lr_scheduler_states=False,
+                                        fp16=True):
+    dtype = torch.half if fp16 else torch.float32
     ds_model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
     data_loader = random_dataloader(model=ds_model,
                                     total_samples=50,
                                     hidden_dim=hidden_dim,
-                                    device=ds_model.device)
+                                    device=ds_model.device,
+                                    dtype=dtype)
     for n, batch in enumerate(data_loader):
         loss = ds_model(batch[0], batch[1])
         ds_model.backward(loss)
@@ -123,7 +129,7 @@ def checkpoint_correctness_verification(args,
     compare_model_states(trained_model, loaded_model)

     if load_optimizer_states:
-        compare_optimizer_states(trained_model, loaded_model, hidden_dim)
+        compare_optimizer_states(trained_model, loaded_model, hidden_dim, fp16)

     if load_lr_scheduler_states:
         compare_lr_scheduler_states(trained_model, loaded_model)
@@ -420,3 +426,34 @@ def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage):
                                         hidden_dim=hidden_dim,
                                         load_optimizer_states=False,
                                         load_lr_scheduler_states=False)
+
+
+def test_checkpoint_fp32_optimizer(tmpdir):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015,
+                "betas": [0.8, 0.999],
+                "eps": 1e-8,
+                "weight_decay": 3e-7
+            }
+        },
+        "fp16": {
+            "enabled": False
+        }
+    }
+
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+
+    model = SimpleModel(hidden_dim, empty_grad=False)
+
+    @distributed_test(world_size=[2])
+    def _test_checkpoint_fp32_optimizer(args, model, hidden_dim):
+        checkpoint_correctness_verification(args,
+                                            model,
+                                            hidden_dim,
+                                            tmpdir,
+                                            fp16=False)
+
+    _test_checkpoint_fp32_optimizer(args=args, model=model, hidden_dim=hidden_dim)
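The new test drives the fp32 path end to end: with "fp16": {"enabled": false} the engine keeps a bare torch.optim.Adam, and fp16=False flows through checkpoint_correctness_verification into both the dataloader dtype and the optimizer-state comparison, exercising the load path fixed in deepspeed_light.py. The decisive config knob can be sketched standalone; the file handling below is my stand-in for what args_from_dict presumably does, not the repo's actual helper:

import json, os, tempfile

# fp16 disabled is the key setting: checkpoint load then takes the
# plain-optimizer branch added in deepspeed_light.py above.
config = {
    "train_batch_size": 2,
    "optimizer": {"type": "Adam", "params": {"lr": 0.00015}},
    "fp16": {"enabled": False},
}

path = os.path.join(tempfile.mkdtemp(), "ds_config.json")
with open(path, "w") as f:
    json.dump(config, f, indent=2)
print("wrote DeepSpeed config to", path)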
登录