From 801c1723450320ca589cc9b6620b9ee5618a5f38 Mon Sep 17 00:00:00 2001
From: Shuai Zheng
Date: Tue, 19 Apr 2022 15:56:00 -0700
Subject: [PATCH] fix file ordering (#1822)

Co-authored-by: Shuai Zheng
Co-authored-by: Jeff Rasley
---
 deepspeed/utils/zero_to_fp32.py    | 18 +++++++++++++++++-
 tests/unit/test_get_optim_files.py | 18 ++++++++++++++++++
 2 files changed, 35 insertions(+), 1 deletion(-)
 create mode 100644 tests/unit/test_get_optim_files.py

diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py
index 7c229518..8e8b0fd1 100755
--- a/deepspeed/utils/zero_to_fp32.py
+++ b/deepspeed/utils/zero_to_fp32.py
@@ -12,6 +12,7 @@ import torch
 import glob
 import math
 import os
+import re
 from collections import OrderedDict
 
 # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
@@ -34,6 +35,19 @@ debug = 0
 device = torch.device('cpu')
 
 
+def atoi(text):
+    return int(text) if text.isdigit() else text
+
+
+def natural_keys(text):
+    '''
+    alist.sort(key=natural_keys) sorts in human order
+    http://nedbatchelder.com/blog/200712/human_sorting.html
+    (See Toothy's implementation in the comments)
+    '''
+    return [atoi(c) for c in re.split(r'(\d+)', text)]
+
+
 def get_model_state_file(checkpoint_dir, zero_stage):
     if not os.path.isdir(checkpoint_dir):
         raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
@@ -52,7 +66,9 @@ def get_model_state_file(checkpoint_dir, zero_stage):
 
 def get_optim_files(checkpoint_dir):
     # XXX: need to test that this simple glob rule works for multi-node setup too
-    optim_files = sorted(glob.glob(os.path.join(checkpoint_dir, "*_optim_states.pt")))
+    optim_files = sorted(glob.glob(os.path.join(checkpoint_dir,
+                                                "*_optim_states.pt")),
+                         key=natural_keys)
 
     if len(optim_files) == 0:
         raise FileNotFoundError(
diff --git a/tests/unit/test_get_optim_files.py b/tests/unit/test_get_optim_files.py
new file mode 100644
index 00000000..68d046bf
--- /dev/null
+++ b/tests/unit/test_get_optim_files.py
@@ -0,0 +1,18 @@
+import os
+import pytest
+import deepspeed
+from deepspeed.utils.zero_to_fp32 import get_optim_files
+
+
+@pytest.mark.parametrize('num_checkpoints', [1, 2, 12, 24])
+def test_get_optim_files(tmpdir, num_checkpoints):
+    saved_files = []
+    for i in range(num_checkpoints):
+        file_name = "zero_" + str(i) + "_optim_states.pt"
+        path_name = os.path.join(tmpdir, file_name)
+        saved_files.append(path_name)
+        with open(path_name, "w") as f:
+            f.write(file_name)
+    loaded_files = get_optim_files(tmpdir)
+    for lf, sf in zip(loaded_files, saved_files):
+        assert lf == sf
-- 
GitLab