Unverified commit 8d3b42c2, authored by Bing Xie, committed by GitHub

Bing/formatting correction (#2764)

* modify engine.py for formatting

* commit formatting changes on engine.py
Parent 258d2831
@@ -2009,11 +2009,14 @@ class DeepSpeedEngine(Module):
         return loss
 
     def is_gradient_accumulation_boundary(self):
-        """Query whether the current micro-batch is at the boundary of
+        """
+        Query whether the current micro-batch is at the boundary of
         gradient accumulation, and thus will trigger gradient reductions and
         an optimizer step.
+
         Returns:
             bool: if the current step is a gradient accumulation boundary.
+
         """
         if self._is_gradient_accumulation_boundary is None:
             return (self.micro_steps + 1) % \
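
For reference, a minimal standalone sketch of the boundary arithmetic this docstring describes (the modulo check over `micro_steps`); the `gradient_accumulation_steps` value here is an assumed config setting for illustration, not code from this diff:

```python
# Illustration only (not DeepSpeed's code): with gradient_accumulation_steps = 4,
# every fourth micro-batch is a boundary that triggers gradient reduction
# and an optimizer step.
gradient_accumulation_steps = 4  # assumed config value

for micro_steps in range(8):
    is_boundary = (micro_steps + 1) % gradient_accumulation_steps == 0
    print(micro_steps, is_boundary)
# micro_steps 3 and 7 print True; all others print False.
```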
@@ -2022,7 +2025,8 @@ class DeepSpeedEngine(Module):
         return self._is_gradient_accumulation_boundary
 
     def set_gradient_accumulation_boundary(self, is_boundary):
-        """Manually overrides the DeepSpeed engine's gradient accumulation boundary state, this is an optional
+        """
+        Manually overrides the DeepSpeed engine's gradient accumulation boundary state. This is an optional
         feature and should be used with care. The state should be set to the intended
         value before each forward/backward. The final forward/backward should have the
         boundary state set to True. This style allows client code to only call engine.step() once after all
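
The override pattern described in that docstring might look like the following sketch; `engine` and `micro_batches` are assumed placeholders (an initialized DeepSpeedEngine and an iterable of micro-batches), not objects from this commit:

```python
# Hypothetical accumulation loop: mark only the final micro-batch as the
# boundary, then call engine.step() exactly once.
for i, batch in enumerate(micro_batches):
    engine.set_gradient_accumulation_boundary(i == len(micro_batches) - 1)
    loss = engine(batch)   # forward pass (placeholder: model returns loss)
    engine.backward(loss)  # gradients accumulate; reduction waits for boundary
engine.step()              # single optimizer step after all micro-batches
```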
@@ -2714,7 +2718,9 @@ class DeepSpeedEngine(Module):
                          load_lr_scheduler_states=True,
                          load_module_only=False,
                          custom_load_fn=None):
-        """Load training checkpoint
+        """
+        Load training checkpoint
+
         Arguments:
             load_dir: Required. Directory to load the checkpoint from
             tag: Checkpoint tag used as a unique identifier for checkpoint, if not provided will attempt to load tag in 'latest' file
@@ -2723,14 +2729,17 @@ class DeepSpeedEngine(Module):
             load_lr_scheduler_states: Optional. Boolean to add the learning rate scheduler states from Checkpoint.
             load_module_only: Optional. Boolean to load only the model weights from the checkpoint. Ex. warmstarting.
             custom_load_fn: Optional. Custom model load function.
+
         Returns:
             A tuple of ``load_path`` and ``client_state``.
             * ``load_path``: Path of the loaded checkpoint. ``None`` if loading the checkpoint failed.
             * ``client_state``: State dictionary used for loading required training states in the client code.
+
         Important: under ZeRO3, one cannot load checkpoint with ``engine.load_checkpoint()`` right
         after ``engine.save_checkpoint()``. This is because ``engine.module`` is partitioned, and
         ``load_checkpoint()`` wants a pristine model. If you insist on doing so, please reinitialize the
         engine before ``load_checkpoint()``.
+
         """
         if tag is None:
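
A usage sketch of the return values documented above, assuming `engine` is an initialized DeepSpeedEngine and `'./checkpoints'` was written by an earlier `save_checkpoint()` call; the `'step'` key is illustrative:

```python
# Resume training state; with tag=None the engine falls back to the
# tag recorded in the 'latest' file.
load_path, client_state = engine.load_checkpoint('./checkpoints')
if load_path is None:
    raise RuntimeError('checkpoint load failed')
# client_state is the dict originally passed to save_checkpoint(client_state=...).
start_step = client_state.get('step', 0)
```

Per the ZeRO3 note, saving and then immediately loading in the same process requires rebuilding the engine (e.g. re-running `deepspeed.initialize()`) before calling `load_checkpoint()`.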
@@ -3062,7 +3071,8 @@ class DeepSpeedEngine(Module):
             logger.warning(msg)
 
     def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True):
-        r"""Save training checkpoint
+        """Save training checkpoint
+
         Arguments:
             save_dir: Required. Directory for saving the checkpoint
             tag: Optional. Checkpoint tag used as a unique identifier for the checkpoint, global step is
@@ -3073,6 +3083,7 @@ class DeepSpeedEngine(Module):
             because each process needs to save its master weights and scheduler+optimizer states. This
             method will hang waiting to synchronize with other processes if it's called just for the
             process with rank 0.
+
         """
         if self.zero_optimization_partition_weights():
             # Prepare for checkpoint save by ensuring all parameters are partitioned
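
Because the call is collective, a sketch like the following runs on every rank with no rank-0 guard; `engine` and `step` are assumed placeholders:

```python
# Every process executes save_checkpoint; wrapping it in `if rank == 0:`
# would leave the other ranks waiting and hang the job.
engine.save_checkpoint('./checkpoints',
                       tag=f'step_{step}',
                       client_state={'step': step})
```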
@@ -3467,17 +3478,23 @@ class DeepSpeedEngine(Module):
         return self.save_16bit_model(save_dir, save_filename)
 
     def save_16bit_model(self, save_dir, save_filename="pytorch_model.bin"):
-        r"""Save 16bit model weights
+        """
+        Save 16bit model weights
+
         This method saves the 16bit model weights at the desired destination.
+
         Arguments:
             save_dir: Required. Directory for saving the model
             save_filename: Optional. Filename to save to. Defaults to ``pytorch_model.bin``
+
         Returns:
             ``True`` when a model has been saved, ``False`` otherwise. It will not be saved if
             stage3_gather_16bit_weights_on_model_save is ``False``.
+
         Important: all processes must call this method and not just the process with rank 0. This is
         because the processes need to work in sync to gather the weights. This method will hang
         waiting to synchronize with other processes if it's called just for the process with rank 0.
         """
         path = os.path.join(save_dir, save_filename)
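
A sketch of the collective call and its documented return value, assuming `engine` is a DeepSpeedEngine running ZeRO stage 3:

```python
# All ranks participate in gathering the 16bit weights; no rank-0 guard.
saved = engine.save_16bit_model('./export', 'pytorch_model.bin')
if not saved:
    # Returns False when stage3_gather_16bit_weights_on_model_save is
    # disabled in the DeepSpeed config.
    print('16bit weights not saved; enable stage3_gather_16bit_weights_on_model_save')
```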