Unverified commit c82756cd authored by Shaden Smith, committed by GitHub

readthedocs upgrade (#402)

Parent 15ca99c6
......@@ -14,6 +14,7 @@ deepspeed.egg-info/
# Website
docs/_site/
docs/build
docs/code-docs/source/_build
docs/code-docs/_build
docs/code-docs/build
.sass-cache/
......
......@@ -841,6 +841,13 @@ class DeepSpeedEngine(Module):
return loss
def is_gradient_accumulation_boundary(self):
"""Query whether the current micro-batch is at the boundary of
gradient accumulation, and thus will trigger gradient reductions and
an optimizer step.
Returns:
bool: if the current step is a gradient accumulation boundary.
"""
return (self.micro_steps + 1) % \
self.gradient_accumulation_steps() == 0
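As a hedged illustration of how this query can be used (``model_engine`` and ``data_loader`` are placeholder names, not part of this change), a caller can restrict logging or other per-update work to the micro-batch that actually triggers an optimizer step:

.. code-block:: python

    # Sketch only: model_engine comes from deepspeed.initialize(); data_loader is
    # an ordinary iterable of micro-batches whose forward pass returns a scalar loss.
    for step, batch in enumerate(data_loader):
        loss = model_engine(batch)
        model_engine.backward(loss)
        if model_engine.is_gradient_accumulation_boundary():
            # This micro-batch closes the accumulation window, so the following
            # step() will reduce gradients and update the weights.
            print(f"weight update at micro-step {step}, loss={loss.item():.4f}")
        model_engine.step()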
......
......@@ -43,10 +43,10 @@ def _tensor_bytes(tensor):
class PipelineEngine(DeepSpeedEngine):
""" A model wrapper for pipeline-parallel execution.
""" A training engine hybrid pipeline, data, and model parallel training.
Parallelism is achieved by executing micro-batches in a pipelined fashion with
gradient accumulation.
This engine is created by ``deepspeed.initialize()`` when a :class:`PipelineModule`
is provided.
"""
def __init__(self, *super_args, **super_kwargs):
super().__init__(*super_args, **super_kwargs)
......@@ -227,10 +227,28 @@ class PipelineEngine(DeepSpeedEngine):
self.num_pipe_buffers = num_buffers
def train_batch(self, data_iter=None):
"""Progress the pipeline to train the next batch of data.
"""Progress the pipeline to train the next batch of data. The engine will ingest
``self.train_batch_size()`` total samples collectively across all workers.
An iterator over training data should be provided as an argument
unless ``deepspeed.initialize()`` was provided a training set. In that case,
the training data will be read automatically.
.. warning::
A total of ``self.gradient_accumulation_steps()`` entries will be pulled
from ``data_iter`` by each pipeline. There must be sufficient
data left in ``data_iter`` or else a ``StopIteration`` will halt training.
DeepSpeed provides a convenience class :class:`deepspeed.utils.RepeatingLoader`
that wraps data loaders to automatically restart upon a ``StopIteration``.
Args:
data_iter (Iterator, optional): Iterator of training data.
Returns:
The arithmetic mean of the losses over all micro-batches.
The arithmetic mean of the losses computed this batch.
"""
if not torch._C.is_grad_enabled():
raise RuntimeError(
......@@ -286,7 +304,9 @@ class PipelineEngine(DeepSpeedEngine):
return self.agg_train_loss
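A hedged sketch of the warning above in practice (``engine`` and ``train_set`` are placeholders for this example): wrapping the loader in ``deepspeed.utils.RepeatingLoader`` keeps ``data_iter`` from being exhausted mid-schedule.

.. code-block:: python

    import torch
    from deepspeed.utils import RepeatingLoader

    # engine: a PipelineEngine returned by deepspeed.initialize() with a PipelineModule;
    # train_set: a torch.utils.data.Dataset. Both are assumptions for this sketch.
    loader = RepeatingLoader(
        torch.utils.data.DataLoader(train_set,
                                    batch_size=engine.train_micro_batch_size_per_gpu()))
    data_iter = iter(loader)

    for _ in range(1000):
        # Each call pulls gradient_accumulation_steps() entries from data_iter and
        # returns the mean loss over this batch's micro-batches.
        loss = engine.train_batch(data_iter=data_iter)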
def eval_batch(self, data_iter):
"""Evaluate the pipeline on a batch of data from ``data_iter``.
"""Evaluate the pipeline on a batch of data from ``data_iter``. The
engine will evaluate ``self.train_batch_size()`` total samples
collectively across all workers.
This method is equivalent to:
......@@ -296,9 +316,21 @@ class PipelineEngine(DeepSpeedEngine):
with torch.no_grad():
output = module(batch)
.. warning::
A total of ``self.gradient_accumulation_steps()`` entries will be pulled
from ``data_iter`` by each pipeline. There must be sufficient
data left in ``data_iter`` or else a ``StopIteration`` will halt evaluation.
DeepSpeed provides a convenience class :class:`deepspeed.utils.RepeatingLoader`
that wraps data loaders to automatically restart upon a ``StopIteration``.
Args:
data_iter (Iterator): Iterator of data to evaluate.
Returns:
The arithmetic mean of the losses over all micro-batches.
The arithmetic mean of the losses computed this batch.
"""
self.module.eval()
self.total_loss = None
......@@ -331,6 +363,14 @@ class PipelineEngine(DeepSpeedEngine):
return self.agg_eval_loss
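A short, hedged evaluation sketch to complement the training loop above (``engine`` and ``val_loader`` are placeholders); no backward pass or optimizer step is involved because the forward passes run without gradients:

.. code-block:: python

    # val_loader: a torch DataLoader over held-out data (placeholder for this sketch).
    val_iter = iter(val_loader)
    losses = []
    for _ in range(10):
        # Like train_batch(), each call consumes gradient_accumulation_steps()
        # entries from the iterator.
        losses.append(engine.eval_batch(data_iter=val_iter).item())
    print(sum(losses) / len(losses))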
def is_first_stage(self):
"""True if this process is in the first stage in the pipeline."""
return self.stage_id == 0
def is_last_stage(self):
"""True if this process is in the last stage in the pipeline."""
return self.stage_id == self.num_stages - 1
def _aggregate_total_loss(self):
# Scale loss, average among DP ranks, and bcast loss to the rest of my DP group
if self.is_last_stage():
......@@ -364,7 +404,7 @@ class PipelineEngine(DeepSpeedEngine):
return agg_loss
def set_dataloader(self, loader):
""" Store a DataLoader to sample for training data. """
""""""
if self.is_first_stage() or self.is_last_stage():
self.training_dataloader = loader
self.data_iterator = iter(self.training_dataloader)
......@@ -993,12 +1033,15 @@ class PipelineEngine(DeepSpeedEngine):
return buffers
def forward(self, *args, **kwargs):
"""Disabled for pipeline parallel training. See ``train_batch()``. """
raise PipelineError("Only train_batch() is accessible in pipeline mode.")
def backward(self, *args, **kwargs):
"""Disabled for pipeline parallel training. See ``train_batch()``. """
raise PipelineError("Only train_batch() is accessible in pipeline mode.")
def step(self, *args, **kwargs):
"""Disabled for pipeline parallel training. See ``train_batch()``. """
raise PipelineError("Only train_batch() is accessible in pipeline mode.")
def mem_status(self, msg, print_rank=-1, reset_max=False):
......@@ -1084,14 +1127,6 @@ class PipelineEngine(DeepSpeedEngine):
self.module.load_state_dir(state_dict, strict=strict)
def is_first_stage(self):
"""True if this process is in the first stage in the pipeline."""
return self.stage_id == 0
def is_last_stage(self):
"""True if this process is in the last stage in the pipeline."""
return self.stage_id == self.num_stages - 1
# A map of PipeInstruction types to methods. Each method will be executed with the
# kwargs provided to the PipeInstruction from the scheduler.
_INSTRUCTION_MAP = {
......
......@@ -26,6 +26,8 @@ class LayerSpec:
LayerSpec stores the type information and parameters for each layer in a
PipelineModule. For example:
.. code-block:: python
nn.Sequential(
torch.nn.Linear(self.in_dim, self.hidden_dim, bias=False),
torch.nn.Linear(self.hidden_dim, self.out_dim)
......@@ -33,6 +35,8 @@ class LayerSpec:
becomes
.. code-block:: python
layer_specs = [
LayerSpec(torch.nn.Linear, self.in_dim, self.hidden_dim, bias=False),
LayerSpec(torch.nn.Linear, self.hidden_dim, self.out_dim)]
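As a hedged sketch of the deferred construction a ``LayerSpec`` provides (the dimensions are made up, and ``build()`` is the helper the spec exposes for materializing the layer), nothing is allocated until a stage actually needs the module:

.. code-block:: python

    import torch
    from deepspeed.pipe import LayerSpec

    # Only the type and constructor arguments are stored; no parameters exist yet.
    spec = LayerSpec(torch.nn.Linear, 128, 256, bias=False)

    # The layer is instantiated lazily, typically by the pipeline stage that owns it.
    layer = spec.build()
    assert isinstance(layer, torch.nn.Linear)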
......@@ -79,44 +83,46 @@ class TiedLayerSpec(LayerSpec):
class PipelineModule(nn.Module):
"""Base class for modules to be parallelized with pipeline parallelism.
Users should subclass PipelineModule and provide layer_specs(), which returns a list
of LayerSpec objects. This sequence of layers represents the pipeline-parallel model.
After initialization, a PipelineModule can be used as a traditional torch.nn.Module.
The forward pass is already provided by this base class. The key assumption is that
the output of each layer can be directly fed as input to the next, like a
torch.nn.Sequential.
The key constraint that enables pipeline parallelism is the representation of the
forward pass as a sequence of layers (i.e., stages) and the enforcement of a
simple interface between them.
Example:
class LinearPipeline(PipelineModule):
def __init__(self, in_dim, hidden_dim, out_dim):
self.in_dim = in_dim
self.hidden_dim = hidden_dim
self.out_dim = out_dim
super().__init__()
def layer_specs(self):
return [LayerSpec(torch.nn.Linear, self.in_dim, self.hidden_dim, bias=False),
LayerSpec(torch.nn.Linear, self.hidden_dim, self.out_dim)]
"""
def __init__(self,
layers,
num_stages=None,
loss_fn=None,
topology=None,
loss_fn=None,
seed_layers=False,
seed_fn=None,
base_seed=1234,
partition_method='parameters',
activation_checkpoint_interval=0,
activation_checkpoint_func=checkpointing.checkpoint):
"""Modules to be parallelized with pipeline parallelism.
The key constraint that enables pipeline parallelism is the
representation of the forward pass as a sequence of layers
and the enforcement of a simple interface between them. The
forward pass is implicitly defined by the ``layers`` argument. The key
assumption is that the output of each layer can be directly fed as
input to the next, like a ``torch.nn.Sequential``. The forward pass is
implicitly:
.. code-block:: python
def forward(self, inputs):
x = inputs
for layer in self.layers:
x = layer(x)
return x
Args:
layers (Iterable): A sequence of layers defining pipeline structure. Can be a ``torch.nn.Sequential`` module.
num_stages (int, optional): The degree of pipeline parallelism. If not specified, ``topology`` must be provided.
topology (``deepspeed.pipe.ProcessTopology``, optional): Defines the axes of parallelism for training. Must be provided if ``num_stages`` is ``None``.
loss_fn (callable, optional): Loss is computed as ``loss = loss_fn(outputs, label)``.
base_seed (int, optional): The starting random seed used when seeding layers. Defaults to 1234.
partition_method (str, optional): The method used to partition model layers across pipeline stages. Defaults to 'parameters'.
activation_checkpoint_interval (int, optional): The granularity of activation checkpointing in terms of number of layers. 0 disables activation checkpointing.
activation_checkpoint_func (callable, optional): The function to use for activation checkpointing. Defaults to ``deepspeed.checkpointing.checkpoint``.
"""
super().__init__()
if num_stages is None and topology is None:
......@@ -488,7 +494,6 @@ class PipelineModule(nn.Module):
self._local_stop = stop
def set_checkpoint_interval(self, interval):
""" Checkpoint activations after each ``interval`` layers. Use 0 to disable. """
assert interval >= 0
self.checkpoint_interval = interval
......
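Tying the constructor arguments documented above together, a hedged end-to-end sketch (the layer sizes, ``num_stages=2``, the loss function, and ``args`` are illustrative placeholders, not values this change prescribes):

.. code-block:: python

    import torch
    import deepspeed
    from deepspeed.pipe import PipelineModule, LayerSpec

    specs = [
        LayerSpec(torch.nn.Linear, 784, 256, bias=False),
        LayerSpec(torch.nn.ReLU),
        LayerSpec(torch.nn.Linear, 256, 10),
    ]
    net = PipelineModule(layers=specs,
                         num_stages=2,
                         loss_fn=torch.nn.CrossEntropyLoss(),
                         partition_method='parameters',
                         activation_checkpoint_interval=0)

    # args holds the usual DeepSpeed command-line/config arguments (placeholder).
    engine, _, _, _ = deepspeed.initialize(args=args,
                                           model=net,
                                           model_parameters=net.parameters())
    # engine.train_batch() then drives the pipelined schedule.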
......@@ -20,7 +20,7 @@ copyright = '2020, Microsoft'
author = 'Microsoft'
# The full version, including alpha/beta/rc tags
release = '0.1.0'
release = '0.3.0'
master_doc = 'index'
......
deepspeed.pt package
====================
Submodules
----------
deepspeed.pt.deepspeed\_config module
-------------------------------------
.. automodule:: deepspeed.pt.deepspeed_config
:members:
:undoc-members:
:show-inheritance:
deepspeed.pt.deepspeed\_constants module
----------------------------------------
.. automodule:: deepspeed.pt.deepspeed_constants
:members:
:undoc-members:
:show-inheritance:
deepspeed.pt.deepspeed\_csr\_tensor module
------------------------------------------
.. automodule:: deepspeed.pt.deepspeed_csr_tensor
:members:
:undoc-members:
:show-inheritance:
deepspeed.pt.deepspeed\_dataloader module
-----------------------------------------
.. automodule:: deepspeed.pt.deepspeed_dataloader
:members:
:undoc-members:
:show-inheritance:
deepspeed.pt.deepspeed\_fused\_lamb module
------------------------------------------
.. automodule:: deepspeed.pt.deepspeed_fused_lamb
:members:
:undoc-members:
:show-inheritance:
deepspeed.pt.deepspeed\_launch module
-------------------------------------
.. automodule:: deepspeed.pt.deepspeed_launch
:members:
:undoc-members:
:show-inheritance:
deepspeed.pt.deepspeed\_light module
------------------------------------
.. automodule:: deepspeed.pt.deepspeed_light
:members:
:undoc-members:
:show-inheritance:
deepspeed.pt.deepspeed\_lr\_schedules module
--------------------------------------------
.. automodule:: deepspeed.pt.deepspeed_lr_schedules
:members:
:undoc-members:
:show-inheritance:
deepspeed.pt.deepspeed\_run module
----------------------------------
.. automodule:: deepspeed.pt.deepspeed_run
:members:
:undoc-members:
:show-inheritance:
deepspeed.pt.deepspeed\_timer module
------------------------------------
.. automodule:: deepspeed.pt.deepspeed_timer
:members:
:undoc-members:
:show-inheritance:
deepspeed.pt.deepspeed\_utils module
------------------------------------
.. automodule:: deepspeed.pt.deepspeed_utils
:members:
:undoc-members:
:show-inheritance:
deepspeed.pt.deepspeed\_zero\_optimizer module
----------------------------------------------
.. automodule:: deepspeed.pt.deepspeed_zero_optimizer
:members:
:undoc-members:
:show-inheritance:
deepspeed.pt.fp16\_optimizer module
-----------------------------------
.. automodule:: deepspeed.pt.fp16_optimizer
:members:
:undoc-members:
:show-inheritance:
deepspeed.pt.fp16\_unfused\_optimizer module
--------------------------------------------
.. automodule:: deepspeed.pt.fp16_unfused_optimizer
:members:
:undoc-members:
:show-inheritance:
deepspeed.pt.loss\_scaler module
--------------------------------
.. automodule:: deepspeed.pt.loss_scaler
:members:
:undoc-members:
:show-inheritance:
Module contents
---------------
.. automodule:: deepspeed.pt
:members:
:undoc-members:
:show-inheritance:
deepspeed package
=================
DeepSpeed
=========
Subpackages
-----------
.. toctree::
:maxdepth: 4
deepspeed.pt
Submodules
----------
deepspeed.git\_version\_info module
-----------------------------------
.. automodule:: deepspeed.git_version_info
:members:
:undoc-members:
:show-inheritance:
deepspeed.install\_config module
--------------------------------
.. automodule:: deepspeed.install_config
:members:
:undoc-members:
:show-inheritance:
Module contents
---------------
.. automodule:: deepspeed
:members:
......
......@@ -40,7 +40,6 @@ Pipeline Parallelism
:maxdepth: 2
pipeline
pipeline-extending
Indices and tables
......
......@@ -5,8 +5,8 @@ DeepSpeed provides routines for checkpointing model state during training.
Loading Training Checkpoints
----------------------------
.. autofunction:: deepspeed.DeepSpeedLight.load_checkpoint
.. autofunction:: deepspeed.DeepSpeedEngine.load_checkpoint
Saving Training Checkpoints
---------------------------
.. autofunction:: deepspeed.DeepSpeedLight.save_checkpoint
.. autofunction:: deepspeed.DeepSpeedEngine.save_checkpoint
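A hedged usage sketch of the two calls above (the directory, tag, and ``client_state`` contents are illustrative; ``model_engine`` is the engine returned by ``deepspeed.initialize()``):

.. code-block:: python

    # Save engine state plus arbitrary user state under a named tag.
    model_engine.save_checkpoint('./checkpoints', tag='step1000',
                                 client_state={'step': 1000})

    # Later: restore the engine and recover the user state saved alongside it.
    load_path, client_state = model_engine.load_checkpoint('./checkpoints',
                                                           tag='step1000')
    step = client_state['step']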
deepspeed
=========
.. toctree::
:maxdepth: 4
deepspeed
Extending Pipeline Parallelism
==============================
.. automodule:: deepspeed.runtime.pipe.schedule
:members:
Pipeline Parallelism
====================
Model Specification
--------------------
.. autoclass:: deepspeed.pipe.PipelineModule
:members:
.. autoclass:: deepspeed.pipe.LayerSpec
:members:
.. autoclass:: deepspeed.pipe.TiedLayerSpec
:members:
Training
--------
.. automodule:: deepspeed.runtime.pipe.engine
:members:
.. automodule:: deepspeed.runtime.pipe.topology
Extending Pipeline Parallelism
------------------------------
.. automodule:: deepspeed.runtime.pipe.schedule
:members:
Training API
============
:func:`deepspeed.initialize` returns a *model engine* in its first argument
of type ``DeepSpeedLight``. This engine is used to progress training:
:func:`deepspeed.initialize` returns a *training engine* in its first argument
of type :class:`DeepSpeedEngine`. This engine is used to progress training:
.. code-block:: python
......@@ -18,12 +18,17 @@ of type ``DeepSpeedLight``. This engine is used to progress training:
Forward Propagation
-------------------
.. autofunction:: deepspeed.DeepSpeedLight.forward
.. autofunction:: deepspeed.DeepSpeedEngine.forward
Backward Propagation
--------------------
.. autofunction:: deepspeed.DeepSpeedLight.backward
.. autofunction:: deepspeed.DeepSpeedEngine.backward
Optimizer Step
--------------
.. autofunction:: deepspeed.DeepSpeedLight.step
.. autofunction:: deepspeed.DeepSpeedEngine.step
Gradient Accumulation
---------------------
.. autofunction:: deepspeed.DeepSpeedEngine.is_gradient_accumulation_boundary
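Putting the calls above together, a minimal loop sketch (``model_engine``, ``data_loader``, and ``criterion`` are placeholders carried over from the initialization example):

.. code-block:: python

    for inputs, labels in data_loader:
        outputs = model_engine(inputs)   # forward propagation
        loss = criterion(outputs, labels)
        model_engine.backward(loss)      # backward propagation
        model_engine.step()              # optimizer step / gradient accumulation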