Unverified commit 18713c68 authored by Joe Mayer, committed by GitHub

Updating API docs (#2586)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent 377c770a
@@ -13,6 +13,24 @@ import typing
class MoE(torch.nn.Module):
"""Initialize an MoE layer.
Arguments:
hidden_size (int): the hidden dimension of the model, importantly this is also the input and output dimension.
expert (torch.nn.Module): the torch module that defines the expert (e.g., an MLP or torch.nn.Linear).
num_experts (int, optional): default=1, the total number of experts per layer.
ep_size (int, optional): default=1, number of ranks in the expert parallel world or group.
k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
use_residual (bool, optional): default=False, make this MoE layer a Residual MoE (https://arxiv.org/abs/2201.05596) layer.
noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample' or 'None'.
drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to infinite capacity).
use_rts (bool, optional): default=True, whether to use Random Token Selection.
use_tutel (bool, optional): default=False, whether to use Tutel optimizations (if installed).
enable_expert_tensor_parallelism (bool, optional): default=False, whether to use tensor parallelism for experts.
"""
def __init__(self,
hidden_size,
expert,
@@ -28,24 +46,6 @@ class MoE(torch.nn.Module):
use_rts=True,
use_tutel: bool = False,
enable_expert_tensor_parallelism: bool = False):
"""Initialize an MoE layer.
Arguments:
hidden_size (int): the hidden dimension of the model, importantly this is also the input and output dimension.
expert (torch.nn.Module): the torch module that defines the expert (e.g., MLP, torch.linear).
num_experts (int, optional): default=1, the total number of experts per layer.
ep_size (int, optional): default=1, number of ranks in the expert parallel world or group.
k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
use_residual (bool, optional): default=False, make this MoE layer a Residual MoE (https://arxiv.org/abs/2201.05596) layer.
noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample' or 'None'.
drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to infinite capacity).
use_rts (bool, optional): default=True, whether to use Random Token Selection.
use_tutel (bool, optional): default=False, whether to use Tutel optimizations (if installed).
enable_expert_tensor_parallelism (bool, optional): default=False, whether to use tensor parallelism for experts
"""
super(MoE, self).__init__()
......
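A minimal usage sketch of the MoE layer documented above may help readers of this diff. The import path, the requirement to initialize distributed communication first, and the three-value return of the forward pass follow common DeepSpeed MoE usage but should be verified against the installed version; the expert module and all sizes below are illustrative only.

```python
import torch
import torch.nn as nn
import deepspeed
from deepspeed.moe.layer import MoE  # assumed import path

# MoE creates expert-parallel process groups, so distributed communication is
# assumed to be initialized first (e.g. when run under the deepspeed launcher).
deepspeed.init_distributed()

hidden_size = 512

# The expert is any torch.nn.Module whose input and output dimension equal hidden_size.
expert = nn.Sequential(
    nn.Linear(hidden_size, 4 * hidden_size),
    nn.ReLU(),
    nn.Linear(4 * hidden_size, hidden_size),
)

moe_layer = MoE(
    hidden_size=hidden_size,
    expert=expert,
    num_experts=8,               # total experts in this layer
    ep_size=1,                   # size of the expert-parallel group
    k=1,                         # top-1 gating
    capacity_factor=1.0,
    noisy_gate_policy='RSample',
)

x = torch.randn(4, 16, hidden_size)       # (batch, sequence, hidden)
# Assumed return signature: the layer output, the auxiliary load-balancing
# loss, and per-expert token counts.
output, l_aux, exp_counts = moe_layer(x)
```

The auxiliary loss `l_aux` is typically added to the task loss so the gate learns to balance tokens across experts.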
@@ -83,6 +83,40 @@ class TiedLayerSpec(LayerSpec):
class PipelineModule(nn.Module):
"""Modules to be parallelized with pipeline parallelism.
The key constraint that enables pipeline parallelism is the
representation of the forward pass as a sequence of layers
and the enforcement of a simple interface between them. The
forward pass is implicitly defined by the module ``layers``. The key
assumption is that the output of each layer can be directly fed as
input to the next, like a ``torch.nn.Sequential``. The forward pass is
implicitly:
.. code-block:: python
def forward(self, inputs):
    x = inputs
    for layer in self.layers:
        x = layer(x)
    return x
.. note::
Pipeline parallelism is not compatible with ZeRO-2 and ZeRO-3.
Args:
layers (Iterable): A sequence of layers defining pipeline structure. Can be a ``torch.nn.Sequential`` module.
num_stages (int, optional): The degree of pipeline parallelism. If not specified, ``topology`` must be provided.
topology (``deepspeed.runtime.pipe.ProcessTopology``, optional): Defines the axes of parallelism for training. Must be provided if ``num_stages`` is ``None``.
loss_fn (callable, optional): Loss is computed as ``loss = loss_fn(outputs, label)``.
seed_layers (bool, optional): Use a different seed for each layer. Defaults to False.
seed_fn (type, optional): The custom seed generating function. Defaults to random seed generator.
base_seed (int, optional): The starting seed. Defaults to 1234.
partition_method (str, optional): The method upon which the layers are partitioned. Defaults to 'parameters'.
activation_checkpoint_interval (int, optional): The granularity of activation checkpointing, in terms of number of layers. 0 disables activation checkpointing.
activation_checkpoint_func (callable, optional): The function to use for activation checkpointing. Defaults to ``deepspeed.checkpointing.checkpoint``.
checkpointable_layers (list, optional): If provided, only layers whose class names appear in this list are activation checkpointed. Defaults to None, which applies no additional filtering.
"""
def __init__(self,
layers,
num_stages=None,
@@ -95,37 +129,6 @@ class PipelineModule(nn.Module):
activation_checkpoint_interval=0,
activation_checkpoint_func=checkpointing.checkpoint,
checkpointable_layers=None):
"""Modules to be parallelized with pipeline parallelism.
The key constraint that enables pipeline parallelism is the
representation of the forward pass as a sequence of layers
and the enforcement of a simple interface between them. The
forward pass is implicitly defined by the module ``layers``. The key
assumption is that the output of each layer can be directly fed as
input to the next, like a ``torch.nn.Sequence``. The forward pass is
implicitly:
.. code-block:: python
def forward(self, inputs):
    x = inputs
    for layer in self.layers:
        x = layer(x)
    return x
.. note::
Pipeline parallelism is not compatible with ZeRO-2 and ZeRO-3.
Args:
layers (Iterable): A sequence of layers defining pipeline structure. Can be a ``torch.nn.Sequential`` module.
num_stages (int, optional): The degree of pipeline parallelism. If not specified, ``topology`` must be provided.
topology (``deepspeed.runtime.pipe.ProcessTopology``, optional): Defines the axes of parallelism axes for training. Must be provided if ``num_stages`` is ``None``.
loss_fn (callable, optional): Loss is computed ``loss = loss_fn(outputs, label)``
base_seed (int, optional): [description]. Defaults to 1234.
partition_method (str, optional): [description]. Defaults to 'parameters'.
activation_checkpoint_interval (int, optional): The granularity activation checkpointing in terms of number of layers. 0 disables activation checkpointing.
activation_checkpoint_func (callable, optional): The function to use for activation checkpointing. Defaults to ``deepspeed.checkpointing.checkpoint``.
"""
super().__init__()
......
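A minimal construction sketch for the class documented above, using only arguments that appear in the docstring. The `deepspeed.pipe` import path and the requirement that distributed communication be initialized before building the module are assumptions based on typical DeepSpeed usage; the layer sizes and stage count are illustrative only.

```python
import torch.nn as nn
import deepspeed
from deepspeed.pipe import PipelineModule  # assumed import path

# PipelineModule partitions the layer sequence across pipeline stages, so the
# process group is assumed to be set up first (e.g. under the deepspeed launcher).
deepspeed.init_distributed()

# A flat sequence of layers: each layer's output feeds the next,
# exactly like the implicit forward pass shown in the docstring.
layers = [
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
]

model = PipelineModule(
    layers=layers,
    num_stages=2,                      # degree of pipeline parallelism
    loss_fn=nn.CrossEntropyLoss(),     # called as loss_fn(outputs, labels)
    partition_method='parameters',     # balance stages by parameter count
    activation_checkpoint_interval=0,  # 0 disables activation checkpointing
)
```

The resulting module is normally handed to `deepspeed.initialize`, and training then goes through the pipeline engine's `train_batch` loop rather than a hand-written forward/backward pass.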