Unverified commit 2887349c, authored by alexandremuzio, committed by GitHub

Adding Tutel to MoE layer (#1528)

Co-authored-by: Alex Muzio <alferre@microsoft.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Parent: cf1f1601
@@ -25,7 +25,9 @@ class MoE(torch.nn.Module):
                 eval_capacity_factor=1.,
                 min_capacity=4,
                 noisy_gate_policy: typing.Optional[str] = None,
-                drop_tokens: bool = True):
+                drop_tokens: bool = True,
+                use_rts=True,
+                use_tutel: bool = False):
"""Initialize an MoE layer.
Arguments:
@@ -44,6 +46,12 @@ class MoE(torch.nn.Module):
        min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
        noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample' or 'None'.
        drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to infinite capacity).
+       use_rts (bool, optional): default=True, whether to use Random Token Selection.
+       use_tutel (bool, optional): default=False, whether to use Tutel optimizations (if installed).
    """
super(MoE, self).__init__()
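As a usage sketch (not part of this commit; every argument except the two new flags is assumed from the surrounding signature), the new knobs are plain constructor arguments on MoE:

import torch
from deepspeed.moe.layer import MoE

# Illustrative expert module; in practice this is the expert MLP.
expert = torch.nn.Linear(512, 512)

# Hypothetical construction showing the two flags added here. Note that
# DeepSpeed's expert-parallel groups must already be initialized before
# the layer can actually be used.
moe = MoE(hidden_size=512,
          expert=expert,
          num_experts=8,
          use_rts=True,     # Random Token Selection (the default)
          use_tutel=True)   # Tutel fast dispatch; silently skipped if Tutel is absent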
@@ -68,10 +76,12 @@ class MoE(torch.nn.Module):
                                              eval_capacity_factor,
                                              min_capacity,
                                              noisy_gate_policy,
-                                             drop_tokens),
+                                             drop_tokens,
+                                             use_rts),
                                     experts,
                                     num_local_experts,
-                                    group=groups.get_expert_parallel_group())
+                                    group=groups.get_expert_parallel_group(),
+                                    use_tutel=use_tutel)
def forward(self, hidden_states, used_token=None):
""" MoE forward
@@ -13,6 +13,7 @@ Copyright 2021 The Microsoft DeepSpeed Team
# LICENSE file in the root directory of this source tree.
from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer
+from deepspeed.utils import logger, log_dist
from typing import Callable, Dict, TYPE_CHECKING, Any, Optional, Tuple, Union, cast
import time
@@ -31,6 +32,16 @@ uniform_map: Dict[torch.device, Callable] = {}
gumbel_map: Dict[torch.device, Callable] = {}
exp_selection_uniform_map: Dict[torch.device, Callable] = {}
+try:
+    # To enable Tutel MoE optimizations:
+    #   python3 -m pip install --user --upgrade git+https://github.com/microsoft/tutel@v0.1.x
+    from tutel import moe as tutel_moe
+    TUTEL_INSTALLED = True
+except ImportError:
+    # Fail silently so we don't spam logs unnecessarily if user isn't using tutel
+    TUTEL_INSTALLED = False
+    pass
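Because the ImportError above is swallowed on purpose, it may not be obvious later why the Tutel code paths stay inactive. A standalone sanity check (a sketch, assuming the pinned install command from the comment above) mirrors the guarded import:

# Quick check that the optional Tutel dependency resolves.
try:
    from tutel import moe as tutel_moe  # the same import the layer guards on
    print("Tutel available; fast dispatch can be enabled.")
except ImportError as err:
    print(f"Tutel not installed ({err}); DeepSpeed falls back to einsum dispatch.")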
def multiplicative_jitter(x, device: torch.device, epsilon=1e-2):
"""
@@ -137,9 +148,11 @@ def top1gating(logits: torch.Tensor,
               min_capacity: int,
               used_token: torch.Tensor = None,
               noisy_gate_policy: Optional[str] = None,
-              drop_tokens: bool = True) -> Tuple[Tensor,
-                                                 Tensor,
-                                                 Tensor]:
+              drop_tokens: bool = True,
+              use_rts: bool = True,
+              use_tutel: bool = False) -> Tuple[Tensor,
+                                                Tensor,
+                                                Tensor]:
"""Implements Top1Gating on logits."""
if noisy_gate_policy == 'RSample':
logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
@@ -179,31 +192,48 @@
    ce = torch.mean(mask1.float(), dim=0)
    l_aux = torch.sum(me * ce) * num_experts

-    uniform = exp_selection_uniform_map.get(logits.device)
-    if uniform is None:
-        uniform = torch.distributions.uniform.Uniform(
-            low=torch.tensor(0.0,
-                             device=logits.device),
-            high=torch.tensor(1.0,
-                              device=logits.device)).rsample
-        exp_selection_uniform_map[logits.device] = uniform
-
-    mask1_rand = mask1 * uniform(mask1.shape)
+    # Random Token Selection
+    if use_rts:
+        uniform = exp_selection_uniform_map.get(logits.device)
+        if uniform is None:
+            uniform = torch.distributions.uniform.Uniform(
+                low=torch.tensor(0.0,
+                                 device=logits.device),
+                high=torch.tensor(1.0,
+                                  device=logits.device)).rsample
+            exp_selection_uniform_map[logits.device] = uniform
+
+        mask1_rand = mask1 * uniform(mask1.shape)
+    else:
+        # Without RTS, rank tokens by the unperturbed mask so mask1_rand is always defined
+        mask1_rand = mask1

    assert logits.shape[0] >= min_capacity, "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size."

    _, top_idx = torch.topk(mask1_rand, k=capacity, dim=0)

    new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
+    mask1 = new_mask1
+
+    if use_tutel:
+        # Tutel doesn't support index values masked with zero
+        # so we need to replace masked indices with -1
+        indices_mask = mask1.sum(dim=1) * num_experts - 1
+        indices1_s = torch.min(indices1_s, indices_mask)

    # Compute locations in capacity buffer
-    locations1 = torch.cumsum(new_mask1, dim=0) - 1
+    if use_tutel:
+        locations1 = tutel_moe.fast_cumsum_sub_one(mask1)
+    else:
+        locations1 = torch.cumsum(mask1, dim=0) - 1
+
+    if use_tutel:
+        gates1_s = (gates * mask1).sum(dim=1)
+        locations1_s = torch.sum(locations1 * mask1, dim=1)
+        return l_aux, capacity, num_experts, [indices1_s,], [locations1_s,], [gates1_s,], exp_counts

    # Store the capacity location for each token
-    locations1_s = torch.sum(locations1 * new_mask1, dim=1)
+    locations1_s = torch.sum(locations1 * mask1, dim=1)

    # Normalize gate probabilities
-    mask1_float = new_mask1.float()
+    mask1_float = mask1.float()
    gates = gates * mask1_float

    locations1_sc = F.one_hot(locations1_s, num_classes=capacity).float()
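To make the capacity and RTS mechanics concrete, here is a minimal self-contained sketch with illustrative sizes (no DeepSpeed or Tutel needed): with capacity c per expert, top-k over the randomly perturbed assignment mask keeps at most c tokens per expert, and the random perturbation decides which of the over-subscribed tokens survive instead of always favoring the earliest ones.

import torch

torch.manual_seed(0)
num_tokens, num_experts, capacity = 6, 2, 2

# mask1[s, e] = 1 if top-1 gating routed token s to expert e.
mask1 = torch.tensor([[1., 0.], [1., 0.], [1., 0.], [0., 1.], [1., 0.], [0., 1.]])

# Random Token Selection: perturb the mask, keep `capacity` tokens per expert.
mask1_rand = mask1 * torch.rand_like(mask1)
_, top_idx = torch.topk(mask1_rand, k=capacity, dim=0)
new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)

print(new_mask1.sum(dim=0))  # tensor([2., 2.]): at most `capacity` survivors per expert

Expert 0 was over-subscribed (4 tokens for 2 slots), so two of its tokens are dropped at random; expert 1 fits within capacity and keeps both of its tokens.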
@@ -314,7 +344,8 @@ class TopKGate(torch.nn.Module):
                 eval_capacity_factor: float = 1.0,
                 min_capacity: int = 4,
                 noisy_gate_policy: Optional[str] = None,
-                drop_tokens: bool = True) -> None:
+                drop_tokens: bool = True,
+                use_rts: bool = True) -> None:
super().__init__()
# Only top-1 and top-2 are supported at the moment.
@@ -330,14 +361,15 @@
        self.wall_clock_breakdown = False
        self.gate_time = 0.0
        self.drop_tokens = drop_tokens
+        self.use_rts = use_rts

    def forward(
-            self,
-            input: torch.Tensor,
-            used_token: torch.Tensor = None
-    ) -> Tuple[Tensor,
-               Tensor,
-               Tensor]:  # type: ignore
+            self,
+            input: torch.Tensor,
+            used_token: torch.Tensor = None,
+            use_tutel: bool = False) -> Tuple[Tensor,
+                                              Tensor,
+                                              Tensor]:  # type: ignore
if self.wall_clock_breakdown:
self.timers('TopKGate').start()
@@ -357,7 +389,9 @@
                self.min_capacity,
                used_token,
                self.noisy_gate_policy if self.training else None,
-                self.drop_tokens)
+                self.drop_tokens,
+                self.use_rts,
+                use_tutel)
else:
gate_output = top2gating(
@@ -392,7 +426,8 @@ class MOELayer(Base):
                 gate: Module,
                 experts: Module,
                 num_local_experts: int,
-                group: Optional[Any] = None) -> None:
+                group: Optional[Any] = None,
+                use_tutel: bool = False) -> None:
super().__init__()
self.gate = gate
self.experts = experts
@@ -405,6 +440,14 @@
        self.timers = SynchronizedWallClockTimer()
        self.wall_clock_breakdown = False

+        self.use_tutel = use_tutel and TUTEL_INSTALLED
+
+        if self.use_tutel:
+            logger.info('Using Tutel optimizations.')
+        elif use_tutel and not TUTEL_INSTALLED:
+            logger.warning("Tutel optimization requested but not installed. "
+                           "Proceeding without Tutel.")
def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
if self.wall_clock_breakdown:
@@ -418,11 +461,23 @@
        # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
        reshaped_input = input[0].reshape(-1, d_model)

-        self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1])
-        dispatched_input = einsum("sec,sm->ecm",
-                                  dispatch_mask.type_as(input[0]),
-                                  reshaped_input)
+        if self.use_tutel:
+            self.l_aux, C, E, indices_, locations_, gates_, self.exp_counts = self.gate(reshaped_input, input[1], True)
+            S, M = reshaped_input.size(0), reshaped_input.size(1)
+
+            if not hasattr(self, '_tutel_dispatcher'):
+                self._tutel_dispatcher = tutel_moe.fast_dispatcher(
+                    E,
+                    C,
+                    M,
+                    dispatch_dtype=reshaped_input.dtype)
+            self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C)
+            dispatched_input = self._tutel_dispatcher.encode(reshaped_input)
+        else:
+            self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1])
+            dispatched_input = einsum("sec,sm->ecm",
+                                      dispatch_mask.type_as(input[0]),
+                                      reshaped_input)
if self.wall_clock_breakdown:
self.timers('falltoall').start()
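Both branches above produce the same dispatched layout: S tokens of width M are scattered into an (E, C, M) buffer with one slot per (expert, capacity) position. A shape-only sketch of the einsum branch (standalone; S/E/C/M follow the names in the code above):

import torch

S, E, C, M = 8, 2, 4, 16  # tokens, experts, capacity, model dim
reshaped_input = torch.randn(S, M)

# One-hot dispatch mask: token s occupies (expert s % E, capacity slot s // E).
dispatch_mask = torch.zeros(S, E, C)
for s in range(S):
    dispatch_mask[s, s % E, s // E] = 1.0

dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask, reshaped_input)
print(dispatched_input.shape)  # torch.Size([2, 4, 16])

Tutel's encode is expected to build the equivalent buffer directly from the indices and locations returned by the gate, skipping the one-hot mask and the einsum.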
@@ -455,9 +510,12 @@
                                              -1,
                                              d_model)

-        combined_output = einsum("sec,ecm->sm",
-                                 combine_weights.type_as(input[0]),
-                                 expert_output)
+        if self.use_tutel:
+            combined_output = self._tutel_dispatcher.decode(expert_output.view(E * C, M))
+        else:
+            combined_output = einsum("sec,ecm->sm",
+                                     combine_weights.type_as(input[0]),
+                                     expert_output)

        a = combined_output.reshape(input[0].shape)
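The combine step is the inverse gather: each token reads back its (expert, capacity) slot, weighted by its gate value, and decode is expected to perform the same reduction on the flattened (E * C, M) view, which is why the two branches are interchangeable. Continuing the shape sketch from above, with identity experts the round trip reproduces the input:

# Continuing the sketch: pretend the experts were identity functions.
expert_output = dispatched_input        # (E, C, M)
combine_weights = dispatch_mask         # gate values folded into the one-hot mask
combined_output = torch.einsum("sec,ecm->sm", combine_weights, expert_output)

print(torch.allclose(combined_output, reshaped_input))  # True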