Unverified commit 801c1723, authored by Shuai Zheng, committed by GitHub

fix file ordering (#1822)

Co-authored-by: Shuai Zheng <shzheng@amazon.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent 56c52238
@@ -12,6 +12,7 @@ import torch
 import glob
 import math
 import os
+import re
 from collections import OrderedDict

 # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
@@ -34,6 +35,19 @@ debug = 0
 device = torch.device('cpu')

+def atoi(text):
+    return int(text) if text.isdigit() else text
+
+def natural_keys(text):
+    '''
+    alist.sort(key=natural_keys) sorts in human order
+    http://nedbatchelder.com/blog/200712/human_sorting.html
+    (See Toothy's implementation in the comments)
+    '''
+    return [atoi(c) for c in re.split(r'(\d+)', text)]
+
 def get_model_state_file(checkpoint_dir, zero_stage):
     if not os.path.isdir(checkpoint_dir):
         raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
@@ -52,7 +66,9 @@ def get_model_state_file(checkpoint_dir, zero_stage):
 def get_optim_files(checkpoint_dir):
     # XXX: need to test that this simple glob rule works for multi-node setup too
-    optim_files = sorted(glob.glob(os.path.join(checkpoint_dir, "*_optim_states.pt")))
+    optim_files = sorted(glob.glob(os.path.join(checkpoint_dir,
+                                                "*_optim_states.pt")),
+                         key=natural_keys)
     if len(optim_files) == 0:
         raise FileNotFoundError(
......
import os
import pytest
import deepspeed
from deepspeed.utils.zero_to_fp32 import get_optim_files


@pytest.mark.parametrize('num_checkpoints', [1, 2, 12, 24])
def test_get_optim_files(tmpdir, num_checkpoints):
    saved_files = []
    for i in range(num_checkpoints):
        file_name = "zero_" + str(i) + "_optim_states.pt"
        path_name = os.path.join(tmpdir, file_name)
        saved_files.append(path_name)
        with open(path_name, "w") as f:
            f.write(file_name)
    loaded_files = get_optim_files(tmpdir)
    for lf, sf in zip(loaded_files, saved_files):
        assert lf == sf
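The 12- and 24-checkpoint cases are what actually exercise the fix: with fewer than ten shards, lexicographic order happens to coincide with numeric order, so the old single sorted() call passed as well. A minimal standalone sketch of the failure mode (helpers copied from the diff above; filenames follow the same pattern the test generates):

import re

def atoi(text):
    return int(text) if text.isdigit() else text

def natural_keys(text):
    return [atoi(c) for c in re.split(r'(\d+)', text)]

names = [f"zero_{i}_optim_states.pt" for i in range(12)]
print(sorted(names)[:3])
# ['zero_0_optim_states.pt', 'zero_10_optim_states.pt', 'zero_11_optim_states.pt']
# -- zero_10 jumps ahead of zero_1
assert sorted(names, key=natural_keys) == names  # natural order matches creation order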