Unverified commit 2a921069, authored by Stas Bekman, committed by GitHub

[model weights] zero_to_fp32 multiple improvements (#1181)

* add live zero checkpoint to fp32 consolidation version

* some more docs

* zero2 model states uses a different filename

* fix

* make debug mode cli configurable

* copy the script only on node 0 process 0

* validate that we have the right number of files

* revamp _get_zero_param_shapes, instrument with easier debug

* correct assertion

* rename API; add even simpler API

* style

* docs improve

* update the docs

* revert the unpartitioned_params detection and report as it's most likely persistent params
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Parent f65ff908
......@@ -126,6 +126,9 @@ class DeepSpeedEngine(Module):
# for debug purposes - can then debug print: debug_get_module_name(module)
debug_extract_module_and_param_names(model)
# needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict
self.param_names = {param: name for name, param in model.named_parameters()}
# Set config using config_params for backwards compat
if self.config is None and config_params is not None:
self.config = config_params
......@@ -1264,7 +1267,7 @@ class DeepSpeedEngine(Module):
self.optimizer.step()
# Quantize the updated parameter if there no overflow
# Quantize the updated parameter if there is no overflow
if self.quantizer:
self.quantizer.quantize(
(self.optimizer.fp16_groups
......@@ -1960,12 +1963,35 @@ class DeepSpeedEngine(Module):
return buffer_names
def _get_param_shapes(self):
def _get_zero_param_shapes(self):
"""Returns a dict of name to shape mapping, only for the flattened fp32 weights saved by the
optimizer. the names are exactly as in state_dict. The order is absolutely important, since
the saved data is just flattened data with no identifiers and requires reconstruction in the
same order it was saved.
We can't rely on self.module.named_parameters() to get the saved tensors, as some params
will be missing and others unsaved and then it'd be impossible to reconstruct state_dict
from the flattened weights.
optimizer.fp16_groups seems to be the easiest to use as it's in all zeroX versions.
"""
param_shapes = OrderedDict()
for name, param in self.module.named_parameters():
param_shapes[name] = param.ds_shape if hasattr(param,
"ds_shape") else param.shape
# print(f"saving param {name} {param_shapes[name]}")
cnt = 0
numel = 0
for fp16_group in self.optimizer.fp16_groups:
for param in fp16_group:
cnt += 1
numel += param.ds_numel if hasattr(param, "ds_numel") else param.numel()
shape = param.ds_shape if hasattr(param, "ds_shape") else param.shape
if param not in self.param_names:
raise ValueError(f"failed to find optimizer param in named params")
name = self.param_names[param]
param_shapes[name] = shape
# uncomment to debug zero_to_fp32.py problems
# if self.global_rank == 0: print(f"saving param {name} {shape} (numel={shape.numel()})")
# if self.global_rank == 0: print(f"Total saved {numel} numels in {cnt} params")
return param_shapes
def _copy_recovery_script(self, save_path):
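To make the remapping above concrete: since the flattened fp32 data carries no identifiers, the engine walks `optimizer.fp16_groups` in order and recovers each tensor's `state_dict` name from the `param -> name` map built in `__init__` (see the first hunk). Below is a minimal standalone sketch of that idea; the toy `model` and `groups` are illustrative stand-ins, not DeepSpeed internals.

``` python
from collections import OrderedDict
import torch

# toy model standing in for self.module (illustrative only)
model = torch.nn.Sequential(torch.nn.Linear(4, 3), torch.nn.Linear(3, 2))

# reverse map from param object to its state_dict name (mirrors self.param_names)
param_names = {param: name for name, param in model.named_parameters()}

# stand-in for optimizer.fp16_groups: any ordered grouping of the same param objects
groups = [list(model[0].parameters()), list(model[1].parameters())]

# rebuild the ordered name -> shape mapping that the zero checkpoint stores
param_shapes = OrderedDict()
for group in groups:
    for param in group:
        name = param_names[param]          # recover the state_dict name by identity
        param_shapes[name] = param.shape   # the engine prefers param.ds_shape when present
print(param_shapes)
```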
......@@ -1981,11 +2007,12 @@ class DeepSpeedEngine(Module):
def _save_zero_checkpoint(self, save_path, tag):
zero_checkpoint_name = self._get_zero_ckpt_name(save_path, tag)
zero_sd = dict(optimizer_state_dict=self.optimizer.state_dict(),
param_shapes=self._get_param_shapes(),
param_shapes=self._get_zero_param_shapes(),
ds_config=self.config,
ds_version=version)
torch.save(zero_sd, zero_checkpoint_name)
self._copy_recovery_script(save_path)
if self.global_rank == 0:
self._copy_recovery_script(save_path)
logger.info('zero checkpoint saved {}'.format(zero_checkpoint_name))
def _zero3_consolidated_fp16_state_dict(self):
......
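Each per-rank shard written by `_save_zero_checkpoint` is a plain `torch.save` pickle, so its contents can be inspected directly. A hedged sketch follows; the directory path is a placeholder, and the glob simply mirrors `get_optim_files` in `zero_to_fp32.py`, since the exact shard filename depends on the ZeRO stage and rank.

``` python
import glob
import os

import torch
import deepspeed  # needed: the pickle references DeepSpeed data structures

ckpt_dir = "checkpoints/global_step14"  # placeholder tag folder
optim_file = sorted(glob.glob(os.path.join(ckpt_dir, "*_optim_states.pt")))[0]

zero_sd = torch.load(optim_file, map_location="cpu")
print(list(zero_sd.keys()))    # optimizer_state_dict, param_shapes, ds_config, ds_version
print(zero_sd["ds_version"])   # DeepSpeed version that wrote the checkpoint
for name, shape in list(zero_sd["param_shapes"].items())[:5]:
    print(name, tuple(shape))
```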
......@@ -960,7 +960,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage3(object):
self.fp16_groups.append(sub_group)
self.sub_group_to_group_id[i] = j
#These are the list of the partitoned parameters
#These are the list of the partitioned parameters
self.fp16_partitioned_groups.append(
[param.ds_tensor for param in self.fp16_groups[i]])
......@@ -1106,7 +1106,12 @@ class FP16_DeepSpeedZeroOptimizer_Stage3(object):
self.fp16_groups.append(sub_group)
self.sub_group_to_group_id[i] = j
#These are the list of the partitoned parameters
# comment out for zero_to_fp32 debug
# if torch.distributed.get_rank() == 0:
# for param in self.fp16_groups[i]:
# print(f"{debug_param2name_id_shape(param)} {param.ds_shape}")
#These are the list of the partitioned parameters
self.fp16_partitioned_groups.append(
[param.ds_tensor for param in self.fp16_groups[i]])
......@@ -1406,14 +1411,16 @@ class FP16_DeepSpeedZeroOptimizer_Stage3(object):
def persistent_parameters(self):
persistent_params = []
total_persistent_parameters = 0
params_count = 0
for _, param in self.module.named_parameters(recurse=True):
if param.ds_numel < self.persistence_threshold:
params_count += 1
param.ds_persist = True
persistent_params.append(param)
total_persistent_parameters += param.ds_numel
print_rank_0(
f'ZeRO 3: Total persistent parameters: {total_persistent_parameters}',
f"ZeRO 3: Total persistent parameters: {total_persistent_parameters} in {params_count} params",
force=False)
return persistent_params
......
......@@ -5,40 +5,42 @@
# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
# application.
#
# example: python zero_to_fp32.py global_step1 pytorch_model.bin
# example: python zero_to_fp32.py . pytorch_model.bin
import argparse
import torch
import glob
import os
from collections import OrderedDict
import deepspeed
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
# DeepSpeed data structures it has to be available in the current python environment.
import deepspeed
from deepspeed.utils import logger
debug = 0
# load to cpu
device = torch.device('cpu')
def get_model_state_file(checkpoint_dir):
def get_model_state_file(checkpoint_dir, zero_stage):
if not os.path.isdir(checkpoint_dir):
raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
# there should be only one file
file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
if zero_stage == 2:
file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
elif zero_stage == 3:
file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
if not os.path.exists(file):
raise FileNotFoundError(f"can't find '{file}' in directory '{checkpoint_dir}'")
raise FileNotFoundError(f"can't find model states file at '{file}'")
return file
def get_optim_files(checkpoint_dir):
if not os.path.isdir(checkpoint_dir):
raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
# XXX: need to test that this simple glob rule works for multi-node setup too
optim_files = sorted(glob.glob(os.path.join(checkpoint_dir, "*_optim_states.pt")))
......@@ -50,16 +52,13 @@ def get_optim_files(checkpoint_dir):
def parse_model_state(file):
# load to cpu
device = torch.device('cpu')
state_dict = torch.load(file, map_location=device)
if "buffer_names" not in state_dict:
raise ValueError(f"{file} is not a model state checkpoint")
buffer_names = state_dict["buffer_names"]
if debug:
print(buffer_names)
print("Found buffers:", buffer_names)
# recover just the buffers while restoring them to fp32 if they were saved in fp16
buffers = {
......@@ -70,10 +69,12 @@ def parse_model_state(file):
return buffers
def parse_optim_states(files):
def parse_optim_states(files, ds_checkpoint_dir):
total_files = len(files)
state_dicts = []
for f in files:
state_dicts.append(torch.load(f))
state_dicts.append(torch.load(f, map_location=device))
if not "zero_stage" in state_dicts[0]['optimizer_state_dict']:
raise ValueError(f"{files[0]} is not a zero checkpoint")
......@@ -81,6 +82,12 @@ def parse_optim_states(files):
world_size = state_dicts[0]['optimizer_state_dict']["partition_count"]
param_shapes = state_dicts[0]["param_shapes"]
if world_size != total_files:
raise ValueError(
f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
"Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
)
# the groups are named differently in each stage
if zero_stage == 2:
fp32_groups_key = "single_partition_of_fp32_groups"
......@@ -109,25 +116,24 @@ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
return partitioned_numel, padding_numel
def convert_zero_chkpt_to_fp32_consolid_state_dict(checkpoint_dir, output_file):
def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
"""
Convert zero 2 or 3 checkpoint into a single fp32 consolidated state_dict file that can be
loaded with ``torch.load(file)`` and used for training without DeepSpeed.
Returns fp32 state_dict reconstructed from ds checkpoint
Args:
- ``checkpoint_dir``: path to the deepspeed checkpoint folder
- ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
- ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
"""
print(f"Processing zero checkpoint '{checkpoint_dir}'")
print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
model_file = get_model_state_file(checkpoint_dir)
optim_files = get_optim_files(checkpoint_dir)
buffers = parse_model_state(model_file)
zero_stage, world_size, param_shapes, fp32_flat_groups = parse_optim_states(optim_files)
optim_files = get_optim_files(ds_checkpoint_dir)
zero_stage, world_size, param_shapes, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
print(
f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
model_file = get_model_state_file(ds_checkpoint_dir, zero_stage)
buffers = parse_model_state(model_file)
# Reconstruction protocol:
#
# - for zero2 we just need to concat the partitions back to back and reconsolidate over one huge
......@@ -144,6 +150,16 @@ def convert_zero_chkpt_to_fp32_consolid_state_dict(checkpoint_dir, output_file):
if zero_stage == 2:
# XXX: memory usage doubles here (zero2)
full_single_fp32_vector = torch.cat(fp32_flat_groups, 0)
avail_numel = full_single_fp32_vector.numel()
elif zero_stage == 3:
avail_numel = fp32_flat_groups[0].numel() * world_size
if debug:
wanted_params = len(param_shapes)
wanted_numel = sum(shape.numel() for shape in param_shapes.values())
# not asserting if there is a mismatch due to possible padding
print(f"Have {avail_numel} numels to process.")
print(f"Need {wanted_numel} numels in {wanted_params} params.")
state_dict = OrderedDict()
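The per-parameter slicing shown in the next hunk follows the reconstruction protocol sketched in the comments above. Here is the zero2 path in isolation as a toy sketch: concatenate the flat partitions, then carve out each parameter by its recorded shape, in order. The data and names below are made up; the real loop additionally handles zero3 padding and per-rank partitioning.

``` python
from collections import OrderedDict

import torch

# toy stand-ins: two flat fp32 partitions and the ordered name -> shape map
fp32_flat_groups = [torch.arange(0., 10.), torch.arange(10., 20.)]
param_shapes = OrderedDict([("a.weight", torch.Size([3, 4])),
                            ("a.bias", torch.Size([3])),
                            ("b.weight", torch.Size([5]))])

full = torch.cat(fp32_flat_groups, 0)   # zero2: partitions concatenated back to back
state_dict, offset = OrderedDict(), 0
for name, shape in param_shapes.items():
    numel = shape.numel()
    state_dict[name] = full.narrow(0, offset, numel).view(shape).clone()
    offset += numel                      # the real script also accounts for padding
print({k: tuple(v.shape) for k, v in state_dict.items()})
```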
......@@ -157,9 +173,12 @@ def convert_zero_chkpt_to_fp32_consolid_state_dict(checkpoint_dir, output_file):
# out-of-core computing solution
offset = 0
total_numel = 0
total_params = 0
for name, shape in param_shapes.items():
unpartitioned_numel = shape.numel()
total_numel += unpartitioned_numel
total_params += 1
if zero_stage == 2:
if debug:
......@@ -177,7 +196,7 @@ def convert_zero_chkpt_to_fp32_consolid_state_dict(checkpoint_dir, output_file):
if debug:
print(
f"{name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
)
# XXX: memory usage doubles here (zero3)
......@@ -189,26 +208,141 @@ def convert_zero_chkpt_to_fp32_consolid_state_dict(checkpoint_dir, output_file):
0).view(shape)
offset += partitioned_numel + partitioned_padding_numel
# the job is done
print(f"Saving fp32 state dict to {output_file} (total_numel={total_numel})")
if zero_stage == 3:
offset *= world_size
# Sanity check
if offset != avail_numel:
raise ValueError(
f"consumed {offset} numels out of {avail_numel} - something is wrong")
print(
f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements"
)
return state_dict
def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
"""
Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
via a model hub.
Args:
- ``checkpoint_dir``: path to the desired checkpoint folder
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided, will attempt to load the tag from the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
Returns:
- pytorch ``state_dict``
Note: this approach may not work if your application doesn't have sufficient free CPU memory and
you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
the checkpoint.
A typical usage might be ::
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
# do the training and checkpoint saving
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
model = model.cpu() # move to cpu
model.load_state_dict(state_dict)
# submit to model hub or save the model to share with others
In this example the ``model`` will no longer be usable in the deepspeed context of the same
application, i.e. you will need to re-initialize the deepspeed engine, since
``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
"""
if tag is None:
latest_path = os.path.join(checkpoint_dir, 'latest')
if os.path.isfile(latest_path):
with open(latest_path, 'r') as fd:
tag = fd.read().strip()
else:
raise ValueError(f"Unable to find 'latest' file at {latest_path}")
ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
if not os.path.isdir(ds_checkpoint_dir):
raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)
def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
"""
Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
Args:
- ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
- ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
"""
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
print(f"Saving fp32 state dict to {output_file}")
torch.save(state_dict, output_file)
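A hedged round-trip sketch of how the file written here is meant to be consumed, per the docstring's ``torch.load(file)`` + ``load_state_dict()`` note; ``MyModel`` and the paths are placeholders.

``` python
import torch
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

# one-time offline conversion
convert_zero_checkpoint_to_fp32_state_dict("checkpoints", "pytorch_model.bin")

# later: no DeepSpeed needed to consume the result
model = MyModel()  # placeholder: your own model class
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)
```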
def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
"""
1. Put the provided model to cpu
2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
3. Load it into the provided model
Args:
- ``model``: the model object to update
- ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
Returns:
- ``model``: modified model
Make sure you have plenty of CPU memory available before you call this function. If you don't
have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
conveniently placed for you in the checkpoint folder.
A typical usage might be ::
from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
# submit to model hub or save the model to share with others
Note that once this has been run, the ``model`` will no longer be usable in the deepspeed context
of the same application, i.e. you will need to re-initialize the deepspeed engine, since
``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
"""
logger.info(f"Extracting fp32 weights")
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
logger.info(f"Overwriting model with fp32 weights")
model = model.cpu()
model.load_state_dict(state_dict, strict=False)
return model
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"checkpoint_dir",
type=str,
help=
"path to the deepspeed checkpoint folder, e.g., path/checkpoint-1/global_step1")
help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
parser.add_argument(
"output_file",
type=str,
help=
"path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-1/pytorch_model.bin)"
"path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)"
)
parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
args = parser.parse_args()
convert_zero_chkpt_to_fp32_consolid_state_dict(args.checkpoint_dir, args.output_file)
debug = args.debug
convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file)
......@@ -270,8 +270,8 @@ You can use this method to save ZeRO-2 weights as well.
If you'd like to get the fp32 weights, we supply a special script that can do offline consolidation. It requires no configuration files or GPUs. Here is an example of its usage:
``` bash
$ cd /path/to/checkpoints_dir
$ ./zero_to_fp32.py global_step1 pytorch_model.bin
$ cd /path/to/checkpoint_dir
$ ./zero_to_fp32.py . pytorch_model.bin
Processing zero checkpoint at global_step1
Detected checkpoint of type zero stage 3, world_size: 2
Saving fp32 state dict to pytorch_model.bin (total_numel=60506624)
......@@ -281,5 +281,21 @@ The `zero_to_fp32.py` gets created automatically when you save a checkpoint.
Note: currently this script requires about 2x the size of the final checkpoint in general RAM.
Alternatively, if you have plenty of spare CPU memory, and instead of writing out a file you want your model updated in place with its fp32 weights, you can do the following at the end of training:
``` python
from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
fp32_model = load_state_dict_from_zero_checkpoint(deepspeed.module, checkpoint_dir)
```
Beware that the model will then be good for saving, but no longer good for continuing training, and will require a fresh `deepspeed.initialize()`.
If you just want the `state_dict`, you can do:
``` python
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir)
```
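If you then want to persist that `state_dict` yourself (roughly what `convert_zero_checkpoint_to_fp32_state_dict` does for you), a minimal sketch, with `checkpoint_dir` as in the snippet above:

``` python
import torch
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir)
torch.save(state_dict, "pytorch_model.bin")  # later loadable with plain torch.load
```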
Congratulations! You have completed the ZeRO tutorial.
......@@ -10,3 +10,15 @@ Loading Training Checkpoints
Saving Training Checkpoints
---------------------------
.. autofunction:: deepspeed.DeepSpeedEngine.save_checkpoint
ZeRO Checkpoint fp32 Weights Recovery
-------------------------------------
DeepSpeed provides routines for extracting fp32 weights from the saved ZeRO checkpoint's optimizer states.
.. autofunction:: deepspeed.utils.zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint
.. autofunction:: deepspeed.utils.zero_to_fp32.load_state_dict_from_zero_checkpoint
.. autofunction:: deepspeed.utils.zero_to_fp32.convert_zero_checkpoint_to_fp32_state_dict
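A minimal usage sketch, assuming ``model`` and ``checkpoint_dir`` come from your own training code::

    from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir)
    model = model.cpu()
    model.load_state_dict(state_dict)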