Unverified commit 6df15873, authored by Olatunji Ruwase, committed by GitHub

Load z3 checkpoints for inference (#4171)

* Load z3 checkpoints for inference

* PR feedback

* Fix API bugs

* Fix typo
Parent b5453990
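For context, here is a minimal usage sketch (not part of this commit) of what the change enables: initializing a ZeRO stage-3 engine without an optimizer and loading a training checkpoint for inference only. The toy model, config values, and checkpoint path are placeholder assumptions; `deepspeed.initialize` and `engine.load_checkpoint` are the public entry points the change affects.

```python
# Hypothetical sketch: load a ZeRO-3 ("z3") checkpoint for inference only.
# The toy model, config values, and checkpoint path are placeholders.
import torch
import deepspeed

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {"stage": 3},  # parameters are partitioned (z3)
    "bf16": {"enabled": True},
}

# No optimizer is passed: the engine is used for inference only.
engine, _, _, _ = deepspeed.initialize(model=model, config=ds_config)

# With this change, skipping optimizer states causes the engine to rebuild a
# consolidated fp32 state dict from the z3 checkpoint and scatter it back into
# the partitioned parameters.
engine.load_checkpoint("checkpoints/my_run",
                       load_optimizer_states=False,
                       load_lr_scheduler_states=False,
                       load_module_only=True)
engine.eval()
```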
@@ -85,6 +85,7 @@ from deepspeed.runtime.data_pipeline.data_routing.helper import remove_random_lt
 from deepspeed.runtime.data_pipeline.data_routing.basic_layer import RandomLayerTokenDrop
 from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
+from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
 from .pipe.module import PipelineModule
 from .utils import get_ma_status
@@ -871,6 +872,12 @@ class DeepSpeedEngine(Module):
 
         return (model_dtype, grad_accum_dtype)
 
+    def _optimizer_has_ckpt_event_prologue(self):
+        return self.optimizer is not None and hasattr(self.optimizer, 'checkpoint_event_prologue')
+
+    def _optimizer_has_ckpt_event_epilogue(self):
+        return self.optimizer is not None and hasattr(self.optimizer, 'checkpoint_event_epilogue')
+
     def _configure_lr_scheduler(self, client_lr_scheduler):
         # First check for scheduler in json configuration
         lr_scheduler = self._scheduler_from_config(self.optimizer)
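These helpers support the inference path, where the engine may have no optimizer at all, or an optimizer-like object that lacks the checkpoint hooks, so `checkpoint_event_prologue`/`checkpoint_event_epilogue` must only be called when they exist. A trivial illustration of the guard pattern; the stand-in class and helper below are made up:

```python
# Illustration of the hasattr guard; _NoHooksOptimizer and _maybe_call_prologue are made up.
class _NoHooksOptimizer:
    """Optimizer-like object without checkpoint_event_* hooks."""


def _maybe_call_prologue(optimizer):
    # Mirrors _optimizer_has_ckpt_event_prologue(): call the hook only if it exists.
    if optimizer is not None and hasattr(optimizer, 'checkpoint_event_prologue'):
        optimizer.checkpoint_event_prologue()


_maybe_call_prologue(None)                  # no optimizer: nothing happens
_maybe_call_prologue(_NoHooksOptimizer())   # hook absent: nothing happens
```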
@@ -2511,14 +2518,23 @@ class DeepSpeedEngine(Module):
             state_dict.update(expert_state_dict)
             moe_layer_id += 1
 
-    def load_module_state_dict(self, checkpoint, strict=True, custom_load_fn=None):
-        module_state_dict = checkpoint['module']
-        if custom_load_fn:
-            custom_load_fn(src=module_state_dict, dst=self.module)
-        else:
-            self.module.load_state_dict(
-                module_state_dict,  # TODO
-                strict=strict)
+    def load_module_state_dict(self, checkpoint, strict=True, custom_load_fn=None, fetch_z3_params=False):
+        if fetch_z3_params:
+            params_to_fetch = [
+                p for p in self.module.parameters()
+                if hasattr(p, 'ds_id') and p.ds_status == ZeroParamStatus.NOT_AVAILABLE
+            ]
+        else:
+            params_to_fetch = []
+
+        with deepspeed.zero.GatheredParameters(params_to_fetch, modifier_rank=0):
+            module_state_dict = checkpoint['module']
+            if custom_load_fn:
+                custom_load_fn(src=module_state_dict, dst=self.module)
+            else:
+                self.module.load_state_dict(
+                    module_state_dict,  # TODO
+                    strict=strict)
 
         if checkpoint.get(FROZEN_PARAM_FRAGMENTS, None) is not None:
             saved_frozen_params = checkpoint[FROZEN_PARAM_FRAGMENTS]
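The new `fetch_z3_params` branch relies on `deepspeed.zero.GatheredParameters`, which temporarily materializes parameters that ZeRO-3 keeps partitioned, lets the full state dict be loaded, and re-partitions them when the context exits. A minimal sketch of that pattern in isolation; the helper name is illustrative and the module is assumed to be z3-partitioned:

```python
# Sketch of the gather-then-load pattern used above (helper name is illustrative).
import deepspeed


def load_full_state_dict_into_z3_module(module, full_state_dict):
    # ZeRO-3 parameters carry a ds_id attribute; gather the partitioned ones.
    params_to_fetch = [p for p in module.parameters() if hasattr(p, 'ds_id')]
    # Inside the context the full tensors are available; modifier_rank=0 means
    # rank 0's in-context updates are broadcast when the params re-partition.
    with deepspeed.zero.GatheredParameters(params_to_fetch, modifier_rank=0):
        module.load_state_dict(full_state_dict, strict=False)
```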
@@ -2649,7 +2665,7 @@ class DeepSpeedEngine(Module):
             )
             return None, None
 
-        if self.zero_optimization_partition_weights():
+        if self._optimizer_has_ckpt_event_prologue():
             # Prepare for checkpoint load by ensuring all parameters are partitioned
             self.optimizer.checkpoint_event_prologue()
@@ -2661,13 +2677,14 @@ class DeepSpeedEngine(Module):
                                                           load_module_only=load_module_only,
                                                           custom_load_fn=custom_load_fn)
 
-        load_zero_checkpoint = self.zero_optimization() or self.bfloat16_enabled()
-        if load_zero_checkpoint and load_path is not None:
+        load_zero_checkpoint = load_optimizer_states and load_path is not None and (self.zero_optimization()
+                                                                                    or self.bfloat16_enabled())
+        if load_zero_checkpoint:
             success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states)
             if not success:
                 self.optimizer._restore_from_bit16_weights()
 
-        if self.zero_optimization_partition_weights():
+        if self._optimizer_has_ckpt_event_epilogue():
             self.optimizer.checkpoint_event_epilogue()
 
         return load_path, client_states
@@ -2694,6 +2711,11 @@ class DeepSpeedEngine(Module):
         if checkpoint is None:
             return None, None
 
+        fetch_z3_params = False
+        if self.zero_optimization_partition_weights() and not load_optimizer_states:
+            checkpoint['module'] = get_fp32_state_dict_from_zero_checkpoint(load_dir)
+            fetch_z3_params = True
+
         if is_pipe_parallel:
             # Pipeline parallelism uses this to load its own checkpoint files.
             self._curr_ckpt_path = os.path.join(load_dir, tag)
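`get_fp32_state_dict_from_zero_checkpoint` is the existing consolidation helper from `deepspeed.utils.zero_to_fp32`; here it is reused in-process to replace `checkpoint['module']` with full fp32 weights rebuilt from the ZeRO shard files. For reference, a standalone sketch of the same utility; the checkpoint directory and stand-in model are placeholders:

```python
# Standalone sketch of the consolidation utility reused by this change.
# The checkpoint directory and the stand-in model are placeholders.
import torch
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

# Rebuild a full fp32 state dict from the ZeRO shard files on disk
# (the tag recorded in the checkpoint's 'latest' file is used by default).
state_dict = get_fp32_state_dict_from_zero_checkpoint("checkpoints/my_run")

plain_model = torch.nn.Linear(8, 2)  # stand-in for the original, unwrapped model
plain_model.load_state_dict(state_dict, strict=False)
torch.save(plain_model.state_dict(), "pytorch_model.bin")
```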
@@ -2714,7 +2736,8 @@ class DeepSpeedEngine(Module):
 
         if not self.load_universal_checkpoint():
             self.load_module_state_dict(checkpoint=checkpoint,
                                         strict=load_module_strict,
-                                        custom_load_fn=custom_load_fn)
+                                        custom_load_fn=custom_load_fn,
+                                        fetch_z3_params=fetch_z3_params)
 
         self.loaded_checkpoint_dp_world_size = checkpoint['dp_world_size']
@@ -2933,8 +2956,8 @@ class DeepSpeedEngine(Module):
         process with rank 0.
 
         """
-        if self.zero_optimization_partition_weights():
-            # Prepare for checkpoint save by ensuring all parameters are partitioned
+        if self._optimizer_has_ckpt_event_prologue():
+            # Custom preparation for checkpoint save, if applicable
             self.optimizer.checkpoint_event_prologue()
 
         rank = self.local_rank if self.use_node_local_storage() else self.global_rank
@@ -2979,7 +3002,7 @@ class DeepSpeedEngine(Module):
             self._create_zero_checkpoint_files(save_dir, tag)
             self._save_zero_checkpoint(save_dir, tag)
 
-        if self.zero_optimization_partition_weights():
+        if self._optimizer_has_ckpt_event_epilogue():
             self.optimizer.checkpoint_event_epilogue()
 
         # Save latest checkpoint tag
@@ -3381,13 +3404,15 @@ class DeepSpeedEngine(Module):
                 get_layer_state_dict(child, prefix + name + ".")
 
         # Prepare for checkpoint save by ensuring all parameters are partitioned
-        self.optimizer.checkpoint_event_prologue()
+        if self._optimizer_has_ckpt_event_prologue():
+            self.optimizer.checkpoint_event_prologue()
 
         see_memory_usage("before get_layer_state_dict", force=False)
         get_layer_state_dict(self.module, prefix="")
         see_memory_usage("after get_layer_state_dict", force=False)
 
-        self.optimizer.checkpoint_event_epilogue()
+        if self._optimizer_has_ckpt_event_epilogue():
+            self.optimizer.checkpoint_event_epilogue()
 
         return state_dict
@@ -1282,7 +1282,7 @@ class PipelineEngine(DeepSpeedEngine):
                                           exclude_frozen_params=exclude_frozen_parameters)
         return None
 
-    def load_module_state_dict(self, checkpoint, strict=True, custom_load_fn=None):
+    def load_module_state_dict(self, checkpoint, strict=True, custom_load_fn=None, fetch_z3_params=False):
         """Override hack to instead use a directory path.
 
         This is important because pipeline models checkpoint by layer instead of rank.
@@ -211,6 +211,7 @@ class DeepSpeedZeRoOffload(object):
                  max_live_parameters=1000000000,
                  param_persistence_threshold=100000,
                  model_persistence_threshold=sys.maxsize,
+                 dp_process_group=None,
                  offload_param_config=None,
                  mpu=None,
                  zero_param_parallel_group=None,
@@ -225,6 +226,7 @@ class DeepSpeedZeRoOffload(object):
         self.module = module
         self.timers = timers
         self.dtype = list(module.parameters())[0].dtype
+        self.dp_process_group = dp_process_group
         self.offload_device = None
         self.offload_param_pin_memory = False
         self.zero_param_parallel_group = zero_param_parallel_group
@@ -181,6 +181,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
             max_live_parameters=max_live_parameters,
             param_persistence_threshold=param_persistence_threshold,
             model_persistence_threshold=model_persistence_threshold,
+            dp_process_group=dp_process_group,
             offload_param_config=offload_param_config,
             mpu=mpu,
             zero_param_parallel_group=zero_param_parallel_group,
@@ -217,7 +218,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
 
         self.reduce_scatter = reduce_scatter
 
-        self.dp_process_group = dp_process_group
+        self.dp_process_group = self.parameter_offload.dp_process_group
 
         self.all2all_process_group = all2all_process_group
@@ -386,6 +387,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
                            max_live_parameters,
                            param_persistence_threshold,
                            model_persistence_threshold,
+                           dp_process_group,
                            offload_param_config,
                            mpu,
                            zero_param_parallel_group,
@@ -401,6 +403,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
             max_live_parameters=max_live_parameters,
             param_persistence_threshold=param_persistence_threshold,
             model_persistence_threshold=model_persistence_threshold,
+            dp_process_group=dp_process_group,
             offload_param_config=offload_param_config,
             mpu=mpu,
             zero_param_parallel_group=zero_param_parallel_group,