Unverified commit 7f90ef4b authored by Olatunji Ruwase, committed by GitHub

Multiple zero stage 3 related fixes (#3886)

* Option to override module apply

* Removing early partitioning in override

* Unit tests

* Cleanup

* Adapt unit test to succeed

* Handle missed params

* Add accelerate

* Code cleanup

* Add doc

* Add doc

* Add doc
Parent 7f26bb6a
......@@ -1041,23 +1041,22 @@ class DeepSpeedEngine(Module):
def _configure_distributed_model(self, model):
self._set_client_model(model)
is_zero3_model = self.zero_optimization_partition_weights() and any(
is_zero_init_model = self.zero_optimization_partition_weights() and any(
[hasattr(param, "ds_id") for param in self.module.parameters()])
if self.fp16_enabled():
if is_zero3_model:
if is_zero_init_model:
self.__check_params(self.module, torch.half)
self.module.half()
elif self.bfloat16_enabled():
if is_zero3_model:
if is_zero_init_model:
self.__check_params(self.module, torch.bfloat16)
self.module.bfloat16()
else:
self.__check_params(self.module, torch.float)
# zero.Init() handles device placement of model
if not self.dont_change_device:
if not (self.dont_change_device or is_zero_init_model):
self.module.to(self.device)
# MoE related initialization
......@@ -1097,7 +1096,7 @@ class DeepSpeedEngine(Module):
self.expert_parallel_group = groups._get_expert_parallel_group_dict()
self.expert_data_parallel_group = groups._get_expert_data_parallel_group_dict()
if not self.amp_enabled():
if not (self.amp_enabled() or is_zero_init_model):
self._broadcast_model()
# check if parameters are duplicated in optimizer param_groups
......
......@@ -38,7 +38,8 @@ ZeRO optimization should be enabled as:
"zero_hpz_partition_size": 1,
"zero_quantized_weights": [true|false],
"zero_quantized_gradients": [true|false],
"memory_efficient_linear": [true|false]
"memory_efficient_linear": [true|false],
"override_module_apply": [true|false],
}
}
"""
......@@ -269,11 +270,17 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
mics_shard_size: int = Field(-1, new_param="mics_shard_size")
mics_hierarchical_params_gather: bool = False
memory_efficient_linear: bool = True
"""
Use memory efficient linear implementation, for Stage 3.
"""
override_module_apply: bool = True
"""
Override nn.Module apply function, for Stage 3.
"""
# Validators
@validator("overlap_comm")
def overlap_comm_valid(cls, field_value, values):
......
......@@ -34,7 +34,6 @@ from deepspeed.utils.debug import (debug_param2name_id_shape, debug_param2name_i
from deepspeed.accelerator import get_accelerator
from ..swap_tensor.partitioned_param_swapper import AsyncPartitionedParameterSwapper, PartitionedParamStatus
param_count = 0
partitioned_param_data_shape = [0]
zero_init_context = 0
top_level_context = None
......@@ -217,12 +216,14 @@ class ZeroParamStatus(Enum):
INFLIGHT = 3
_orig_torch_tensor = torch.tensor
_orig_torch_empty = torch.empty
_orig_torch_zeros = torch.zeros
_orig_torch_ones = torch.ones
_orig_torch_full = torch.full
_orig_torch_arange = torch.arange
_orig_torch_eye = torch.eye
_orig_torch_randn = torch.randn
def zero_wrapper_for_fp_tensor_constructor(fn: Callable, target_fp_dtype: torch.dtype) -> Callable:
......@@ -288,6 +289,8 @@ empty_buffers = {}
# Inserts _post_init_method at the end of init method
# for all sub classes of torch.nn.Module
class InsertPostInitMethodToModuleSubClasses(object):
num_module_parameters = 0
num_module_elements = 0
def __init__(self, enabled=True, mem_efficient_linear=True, ds_config=None, dtype=None):
self.mem_efficient_linear = mem_efficient_linear
......@@ -324,7 +327,10 @@ class InsertPostInitMethodToModuleSubClasses(object):
top_level_context = None
if dist.get_rank() == 0:
logger.info("finished initializing model with %.2fB parameters", param_count / 1e9)
billion_elems = InsertPostInitMethodToModuleSubClasses.num_module_elements / 1e9
num_params = InsertPostInitMethodToModuleSubClasses.num_module_parameters
logger.info(
f"finished initializing model - num_params = {num_params}, num_elems = {billion_elems:.2f}B")
# Now that we cleaned up the metaclass injection, raise the exception.
if exc_type is not None:
......@@ -381,14 +387,16 @@ class InsertPostInitMethodToModuleSubClasses(object):
3. broadcasts root rank's parameters to the other ranks
4. re-partitions the parameters
"""
if not all(is_zero_param(p) for p in module_to_apply_fn_to.parameters(recurse=False)):
raise RuntimeError(f"not all parameters for {module_to_apply_fn_to.__class__.__name__}, "
f"were zero params, is it possible that the parameters were "
f"overwritten after they were initialized? "
f"params: {[p for p in module_to_apply_fn_to.parameters(recurse=False)]} ")
# TODO Delay error checking for dangling partitioned parameters to post module init
# raise RuntimeError(f"not all parameters for {module_to_apply_fn_to.__class__.__name__}, "
# f"were zero params, is it possible that the parameters were "
# f"overwritten after they were initialized? "
# f"params: {[p for p in module_to_apply_fn_to.parameters(recurse=False)]} ")
params_to_apply_fn_to: Iterable[Parameter] = list(
sorted(module_to_apply_fn_to.parameters(recurse=False), key=lambda p: p.ds_id))
sorted([p for p in module_to_apply_fn_to.parameters(recurse=False) if is_zero_param(p)],
key=lambda p: p.ds_id))
for param in params_to_apply_fn_to:
param.all_gather()
......@@ -464,7 +472,8 @@ class InsertPostInitMethodToModuleSubClasses(object):
# Replace .__init__() for future subclasses of torch.nn.Module
torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass)
torch.nn.modules.module.Module.apply = apply_with_gather(torch.nn.modules.module.Module._old_apply)
if Init.override_module_apply:
torch.nn.modules.module.Module.apply = apply_with_gather(torch.nn.modules.module.Module._old_apply)
self._add_tensor_creation_wrappers()
......@@ -489,7 +498,8 @@ class InsertPostInitMethodToModuleSubClasses(object):
# putting methods back the way we found them
torch.nn.modules.module.Module.__init_subclass__ = torch.nn.modules.module.Module._old_init_subclass
torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply
if Init.override_module_apply:
torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply
self._remove_tensor_creation_wrappers()
......@@ -497,21 +507,25 @@ class InsertPostInitMethodToModuleSubClasses(object):
def _add_tensor_creation_wrappers(self):
torch.Tensor.__new__ = get_new_tensor_fn_for_dtype(self.dtype)
torch.tensor = zero_wrapper_for_fp_tensor_constructor(_orig_torch_tensor, self.dtype)
torch.empty = zero_wrapper_for_fp_tensor_constructor(_orig_torch_empty, self.dtype)
torch.zeros = zero_wrapper_for_fp_tensor_constructor(_orig_torch_zeros, self.dtype)
torch.ones = zero_wrapper_for_fp_tensor_constructor(_orig_torch_ones, self.dtype)
torch.full = zero_wrapper_for_fp_tensor_constructor(_orig_torch_full, self.dtype)
torch.arange = zero_wrapper_for_fp_tensor_constructor(_orig_torch_arange, self.dtype)
torch.eye = zero_wrapper_for_fp_tensor_constructor(_orig_torch_eye, self.dtype)
torch.randn = zero_wrapper_for_fp_tensor_constructor(_orig_torch_randn, self.dtype)
def _remove_tensor_creation_wrappers(self):
torch.Tensor.__new__ = torch.Tensor.__old_new__
torch.tensor = _orig_torch_tensor
torch.empty = _orig_torch_empty
torch.zeros = _orig_torch_zeros
torch.ones = _orig_torch_ones
torch.full = _orig_torch_full
torch.arange = _orig_torch_arange
torch.eye = _orig_torch_eye
torch.randn = _orig_torch_randn
def shutdown_init_context():
......@@ -687,6 +701,7 @@ class Init(InsertPostInitMethodToModuleSubClasses):
num_persisted_parameters = 0
num_persisted_elements = 0
apply_param_persistence = False
override_module_apply = get_config_default(DeepSpeedZeroConfig, "override_module_apply")
def __init__(self,
module=None,
......@@ -845,9 +860,12 @@ class Init(InsertPostInitMethodToModuleSubClasses):
self.quantizer_module = CUDAQuantizer()
print_rank_0(f'Using quantizer: {self.quantizer_module.__class__.__name__}', force=True)
if _ds_config is not None and _ds_config.zero_config.offload_param is not None:
remote_device = _ds_config.zero_config.offload_param.device
pin_memory = _ds_config.zero_config.offload_param.pin_memory
if _ds_config is not None:
Init.override_module_apply = _ds_config.zero_config.override_module_apply
if _ds_config.zero_config.offload_param is not None:
remote_device = _ds_config.zero_config.offload_param.device
pin_memory = _ds_config.zero_config.offload_param.pin_memory
self._validate_remote_device(remote_device, _ds_config)
......@@ -877,12 +895,21 @@ class Init(InsertPostInitMethodToModuleSubClasses):
Init.param_persistence_threshold = ds_config.zero_config.param_persistence_threshold
Init.model_persistence_threshold = ds_config.zero_config.model_persistence_threshold // self.num_partitions
def _zero_init_param(self, param):
self._convert_to_deepspeed_param(param)
if dist.get_world_group() == self.get_dp_process_group():
dist.broadcast(param, 0, self.get_dp_process_group())
else:
dist.broadcast(param, dist.get_global_rank(self.get_dp_process_group(), 0), self.get_dp_process_group())
param.partition()
def _convert_to_zero_parameters(self, param_list):
for param in param_list:
if is_zero_param(param):
continue
self._convert_to_deepspeed_param(param)
param.partition()
param.data = param.data.to(self.local_device)
self._zero_init_param(param)
def _validate_remote_device(self, remote_device, ds_config):
if ds_config is not None:
......@@ -904,28 +931,19 @@ class Init(InsertPostInitMethodToModuleSubClasses):
print_rank_0(f'Converting Params in {module.__class__.__name__}', force=False)
see_memory_usage(f"Before converting and partitioning params in {module.__class__.__name__}", force=False)
global param_count
for name, param in module.named_parameters(recurse=False):
param_count += param.numel()
print_rank_0(f'Analyzing param {name} in {module.__class__.__name__}', force=False)
InsertPostInitMethodToModuleSubClasses.num_module_parameters += 1
InsertPostInitMethodToModuleSubClasses.num_module_elements += param.numel()
if not is_zero_param(param):
self._convert_to_deepspeed_param(param)
if not get_accelerator().on_accelerator(param):
param.data = param.data.to(self.local_device)
self._zero_init_param(param)
print_rank_0(
f"Partitioning param {debug_param2name_id_shape(param)} module={debug_module2name(module)}")
if get_accelerator().on_accelerator(param):
if dist.get_world_group() == self.get_dp_process_group():
dist.broadcast(param, 0, self.get_dp_process_group())
else:
dist.broadcast(param, dist.get_global_rank(self.get_dp_process_group(), 0),
self.get_dp_process_group())
else:
if dist.get_rank() == 0:
logger.warn(f"param `{name}` in {module.__class__.__name__} "
f"not on GPU so was not broadcasted from rank 0")
param.partition()
see_memory_usage(
f"Param count {param_count}. After converting and partitioning params in {module.__class__.__name__}",
f"Param count {InsertPostInitMethodToModuleSubClasses.num_module_elements}. After converting and partitioning params in {module.__class__.__name__}",
force=False)
def _convert_to_deepspeed_param(self, param):
......@@ -1342,7 +1360,6 @@ class Init(InsertPostInitMethodToModuleSubClasses):
tensor_size = self._aligned_size(param)
partition_size = tensor_size // self.num_partitions
if param.ds_tensor is None:
final_location = None
if self.remote_device == OffloadDeviceEnum.nvme and self.param_swapper.swappable_tensor(
......
......@@ -309,6 +309,17 @@ DeepSpeed can automatically detect the following external parameter scenarios:
.. autofunction:: deepspeed.zero.unregister_external_parameter
.. `Module.apply <https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=module+apply#torch.nn.Module.apply>`_
Overriding Module.apply
===============================
A convenient mechanism for customizing model initialization is `Module.apply <https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=module+apply#torch.nn.Module.apply>`_.
With ZeRO stage 3, ``Module.apply`` implementations must account for parameter partitioning by ``zero.Init`` during model initialization. By default, ZeRO stage 3 handles this automatically
by overriding ``Module.apply`` to ensure that parameters are gathered before the applied function accesses them. The benefit of this approach is development convenience, since users are spared
the burden of manually coordinating partitioned parameters in ``Module.apply``. The downside, however, is slow model initialization, since all of the model's parameters (potentially billions)
are gathered even though ``Module.apply`` is commonly used to customize only a few of them. Developers can disable this default behavior by setting the ``override_module_apply`` configuration
knob to ``False``, trading faster model initialization for manual handling of partitioned parameters in their ``Module.apply`` implementations.
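
For illustration, the sketch below disables the override and instead gathers partitioned parameters explicitly inside the applied function. ``MyModel`` and ``init_weights`` are hypothetical placeholders and the configuration values are illustrative only; ``deepspeed.zero.Init``, ``deepspeed.zero.GatheredParameters``, and the ``override_module_apply`` knob are the actual DeepSpeed pieces being exercised::

    import deepspeed
    import torch

    # Illustrative ZeRO stage 3 config that opts out of the Module.apply override.
    ds_config = {
        "train_batch_size": 8,
        "zero_optimization": {
            "stage": 3,
            "override_module_apply": False,
        },
    }

    def init_weights(module):
        # With override_module_apply disabled, the applied function must gather
        # partitioned parameters itself before touching their data.
        if isinstance(module, torch.nn.Linear):
            with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
                torch.nn.init.xavier_uniform_(module.weight)

    with deepspeed.zero.Init(config_dict_or_path=ds_config):
        model = MyModel()  # hypothetical model class

    model.apply(init_weights)  # parameters are gathered manually in init_weights

Leaving ``override_module_apply`` at its default of ``True`` makes the explicit ``GatheredParameters`` block unnecessary, at the cost of gathering every parameter during ``Module.apply``.
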
Memory-Centric Tiling
---------------------
......
accelerate
clang-format==16.0.2
coverage
docutils<0.18
......
......@@ -7,6 +7,9 @@ import torch
from unit.common import DistributedTest
from transformers import VisionEncoderDecoderModel
from transformers.deepspeed import HfDeepSpeedConfig
import deepspeed
......@@ -44,3 +47,26 @@ class TestShutdownInNestingInit(DistributedTest):
# ensure that zero3 processed the parameter
assert hasattr(model2.weight, "ds_id")
deepspeed_engine2, *_ = deepspeed.initialize(model=model2, config_params=ds_config)
class TestNestedParallelInit(DistributedTest):
world_size = 1
# Testing a model with composed and nested zero.Inits, with 3 zero.Init contexts, 1 parent and 2 children.
# The skeleton of the model is like so
#
# class VisionEncoderDecoderModel(...):
# def __init__(self):
# encoder = AutoModel.from_config(config.encoder)
# decoder = AutoModelForCausalLM.from_config(config.decoder)
#
# And the user calls it like below:
# VisionEncoderDecoderModel.from_pretrained(...)
# which calls this constructor inside zero.Init
def test_nested_parallel_init(self):
ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3))
dschf = HfDeepSpeedConfig(ds_config) # keep this object alive
model = VisionEncoderDecoderModel.from_pretrained(
"hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2")
assert all([hasattr(p, 'ds_id') for p in model.parameters()])