Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Greenplum
DeepSpeed
提交
e801e6d7
D
DeepSpeed
项目概览
Greenplum
/
DeepSpeed
上一次同步 12 个月
通知
10
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeed
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
未验证
提交
e801e6d7
编写于
9月 05, 2023
作者:
A
Alexander Jipa
提交者:
GitHub
9月 05, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
skipping redundant MoE optimizer state loading (#4120)
Co-authored-by:
N
Alexander Jipa
<
azzhipa@amazon.com
>
上级
9894c06a
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
26 addition
and
16 deletion
+26
-16
deepspeed/runtime/engine.py
deepspeed/runtime/engine.py
+14
-11
tests/unit/checkpoint/common.py
tests/unit/checkpoint/common.py
+12
-5
未找到文件。
deepspeed/runtime/engine.py
浏览文件 @
e801e6d7
...
...
@@ -2759,26 +2759,29 @@ class DeepSpeedEngine(Module):
self
.
loaded_checkpoint_dp_world_size
=
checkpoint
[
'dp_world_size'
]
optim_checkpoint
=
None
if
load_module_only
:
deepspeed_states
=
[
'module'
]
if
self
.
optimizer
is
not
None
and
self
.
fp16_enabled
():
self
.
optimizer
.
refresh_fp32_params
()
else
:
if
self
.
has_moe_layers
:
largest_group_name
=
groups
.
_get_max_expert_size_name
()
expp_rank
=
groups
.
_get_expert_parallel_rank
(
largest_group_name
)
optim_load_path
=
self
.
_get_optimizer_ckpt_name
(
load_dir
,
tag
,
expp_rank
)
optim_checkpoint
=
self
.
checkpoint_engine
.
load
(
optim_load_path
,
map_location
=
torch
.
device
(
'cpu'
))
else
:
optim_checkpoint
=
checkpoint
has_zero_optimizer_state
=
self
.
zero_optimization
()
or
self
.
bfloat16_enabled
()
if
load_optimizer_states
and
self
.
optimizer
is
not
None
and
not
has_zero_optimizer_state
:
if
self
.
fp16_enabled
():
if
self
.
has_moe_layers
:
largest_group_name
=
groups
.
_get_max_expert_size_name
()
expp_rank
=
groups
.
_get_expert_parallel_rank
(
largest_group_name
)
optim_load_path
=
self
.
_get_optimizer_ckpt_name
(
load_dir
,
tag
,
expp_rank
)
optim_checkpoint
=
self
.
checkpoint_engine
.
load
(
optim_load_path
,
map_location
=
torch
.
device
(
'cpu'
))
else
:
optim_checkpoint
=
checkpoint
if
self
.
fp16_enabled
()
or
self
.
bfloat16_enabled
():
self
.
optimizer
.
load_state_dict
(
optim_checkpoint
[
'optimizer'
],
load_optimizer_states
=
load_optimizer_states
)
else
:
self
.
optimizer
.
load_state_dict
(
optim_checkpoint
[
'optimizer'
])
optim_checkpoint
=
checkpoint
self
.
optimizer
.
load_state_dict
(
optim_checkpoint
[
'optimizer'
])
if
load_lr_scheduler_states
and
self
.
lr_scheduler
is
not
None
:
self
.
lr_scheduler
.
load_state_dict
(
checkpoint
[
'lr_scheduler'
])
...
...
@@ -2835,7 +2838,7 @@ class DeepSpeedEngine(Module):
client_state
=
{
key
:
value
for
key
,
value
in
checkpoint
.
items
()
if
not
key
in
deepspeed_states
}
if
not
load_optimizer_states
and
not
load_module_only
:
if
optim_checkpoint
is
not
None
:
client_state
[
'optimizer'
]
=
optim_checkpoint
[
'optimizer'
]
return
load_path
,
client_state
...
...
tests/unit/checkpoint/common.py
浏览文件 @
e801e6d7
...
...
@@ -15,6 +15,7 @@ from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
from
deepspeed.runtime.zero.partition_parameters
import
ZeroParamStatus
from
unit.simple_model
import
*
from
unittest.mock
import
MagicMock
,
patch
def
compare_deepspeed_states
(
saved_model
,
loaded_model
):
...
...
@@ -209,11 +210,17 @@ def checkpoint_correctness_verification(config_dict,
loaded_model
=
create_deepspeed_model
(
config_dict
=
config_dict
,
model
=
models
[
1
],
base_optimizer
=
base_optimizers
[
1
])
assert
list
(
trained_model
.
parameters
())[
0
].
dtype
==
list
(
loaded_model
.
parameters
())[
0
].
dtype
loaded_model
.
load_checkpoint
(
save_folder
,
tag
=
save_tag
,
load_optimizer_states
=
load_optimizer_states
,
load_lr_scheduler_states
=
load_lr_scheduler_states
,
load_module_only
=
load_module_only
)
context
=
patch
.
object
(
loaded_model
,
"_get_optimizer_ckpt_name"
,
wraps
=
loaded_model
.
_get_optimizer_ckpt_name
)
if
not
load_optimizer_states
else
MagicMock
()
with
context
as
optim_load_state_dict_mock
:
loaded_model
.
load_checkpoint
(
save_folder
,
tag
=
save_tag
,
load_optimizer_states
=
load_optimizer_states
,
load_lr_scheduler_states
=
load_lr_scheduler_states
,
load_module_only
=
load_module_only
)
if
not
load_optimizer_states
:
# should not attempt to get the file name to load it
optim_load_state_dict_mock
.
assert_not_called
()
compare_model_states
(
trained_model
,
loaded_model
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录