未验证 提交 103884ae 编写于 作者: E Eugene Cheah 提交者: GitHub

Update zero_to_fp32.py (#3936)

Since deepspeed 1 checkpoint structure is identical to deepspeed 2 (AFAIK), we should just change the version check and add support accordingly
上级 aef6c65c
......@@ -5,7 +5,7 @@
# DeepSpeed Team
# This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets
# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
# application.
......@@ -63,7 +63,7 @@ def get_model_state_file(checkpoint_dir, zero_stage):
raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
# there should be only one file
if zero_stage == 2:
if zero_stage <= 2:
file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
elif zero_stage == 3:
file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
......@@ -164,14 +164,14 @@ def parse_optim_states(files, ds_checkpoint_dir):
)
# the groups are named differently in each stage
if zero_stage == 2:
if zero_stage <= 2:
fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
elif zero_stage == 3:
fp32_groups_key = FP32_FLAT_GROUPS
else:
raise ValueError(f"unknown zero stage {zero_stage}")
if zero_stage == 2:
if zero_stage <= 2:
fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
elif zero_stage == 3:
# if there is more than one param group, there will be multiple flattened tensors - one
......@@ -206,7 +206,7 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
zero_model_states = parse_model_states(model_files)
print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
if zero_stage == 2:
if zero_stage <= 2:
return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states)
elif zero_stage == 3:
return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册