Unverified commit 2f626978, authored by Shaden Smith, committed by GitHub

Pipeline warnings and checkpoint portability (#588)

* Switch from deprecated allreduce interface.

* Make pipeline checkpoint files portable.
Parent commit: e8b126d9
......@@ -33,6 +33,7 @@ from deepspeed.utils import logger, log_dist
from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer
from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop
from .pipe.module import PipelineModule
from .utils import ensure_directory_exists
from ..ops.op_builder import UtilsBuilder
from ..ops.adam import DeepSpeedCPUAdam
......@@ -1355,6 +1356,10 @@ class DeepSpeedEngine(Module):
logger.info(f'rank: {self.global_rank} loading checkpoint: {load_path}')
checkpoint = torch.load(load_path, map_location=lambda storage, loc: storage)
if isinstance(self.module, PipelineModule):
# Pipeline parallelism uses this to load its own checkpoint files.
self._curr_ckpt_path = os.path.join(load_dir, tag)
self.load_module_state_dict(state_dict=checkpoint['module'],
strict=load_module_strict)
if not self.zero_optimization():
......@@ -1522,8 +1527,8 @@ class DeepSpeedEngine(Module):
save_path = self._get_ckpt_name(save_dir, tag)
# A hack to save the checkpointing directory. Pipeline parallelism overrides
# module_state_dict() and uses this path to save the model. module_state_dict()
# then instead just returns self._curr_save_path.
self._curr_save_path = os.path.dirname(save_path)
# then instead just returns None.
self._curr_ckpt_path = os.path.join(save_dir, tag)
state = {
'module':
......
......@@ -52,6 +52,9 @@ class PipelineEngine(DeepSpeedEngine):
super().__init__(*super_args, **super_kwargs)
assert isinstance(self.module, PipelineModule), "model must base PipelineModule"
# We schedule the all-reduces, so disable it in super().backward()
self.enable_backward_allreduce = False
# pipeline step for logging
self.log_batch_step_id = -1
......@@ -546,7 +549,7 @@ class PipelineEngine(DeepSpeedEngine):
# The last stage just runs backward on the loss using DeepSpeed's typical
# mechanisms.
if self.is_last_stage():
super().backward(self.loss, allreduce_gradients=False)
super().backward(self.loss)
self.mem_status('AFTER BWD')
return
......@@ -1100,31 +1103,31 @@ class PipelineEngine(DeepSpeedEngine):
is ``save_state_dict()``.
Returns:
str: The directory path where the checkpoint was saved.
None
"""
assert isinstance(self.module, PipelineModule)
assert self._curr_save_path is not None, \
assert self._curr_ckpt_path is not None, \
"PipelineEngine expects module_state_dict() to be called from save_checkpoint()"
self.module.save_state_dict(self._curr_save_path)
return self._curr_save_path
self.module.save_state_dict(self._curr_ckpt_path)
return None
def load_module_state_dict(self, state_dict, strict=True):
    """Override hack to instead use a directory path.

    This is important because pipeline models checkpoint by layer instead of rank.

    If ``state_dict`` is not ``None`` or a ``str``, we revert to ``super()``
    expecting a ``dict``.

    Args:
        state_dict (str, None): unused
        strict (bool, optional): Strict state loading. Defaults to True.
    """
    # A plain dict means the caller supplied a regular (non-pipeline)
    # state dict, so defer to the base-class loader.
    if (state_dict is not None) and (not isinstance(state_dict, str)):
        super().load_module_state_dict(state_dict, strict)
        return

    # Pipeline-parallel checkpoints are saved per-layer under the current
    # checkpoint directory (set by the engine in _load_checkpoint), so load
    # from that directory rather than from a monolithic state dict.
    self.module.load_state_dir(load_dir=self._curr_ckpt_path, strict=strict)
# A map of PipeInstruction types to methods. Each method will be executed with the
# kwargs provided to the PipeInstruction from the scheduler.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.