Unverified · Commit 2887349c · authored by alexandremuzio · committed by GitHub

Adding Tutel to MoE layer (#1528)

Co-authored-by: Alex Muzio <alferre@microsoft.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Parent: cf1f1601
@@ -25,7 +25,9 @@ class MoE(torch.nn.Module):
                  eval_capacity_factor=1.,
                  min_capacity=4,
                  noisy_gate_policy: typing.Optional[str] = None,
-                 drop_tokens: bool = True):
+                 drop_tokens: bool = True,
+                 use_rts=True,
+                 use_tutel: bool = False):
         """Initialize an MoE layer.

         Arguments:
@@ -44,6 +46,12 @@ class MoE(torch.nn.Module):
             min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.

             noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample' or 'None'.
+
+            drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to infinite capacity).
+
+            use_rts (bool, optional): default=True, whether to use Random Token Selection.
+
+            use_tutel (bool, optional): default=False, whether to use Tutel optimizations (if installed).
         """

         super(MoE, self).__init__()
@@ -68,10 +76,12 @@ class MoE(torch.nn.Module):
                                              eval_capacity_factor,
                                              min_capacity,
                                              noisy_gate_policy,
-                                             drop_tokens),
+                                             drop_tokens,
+                                             use_rts),
                                      experts,
                                      num_local_experts,
-                                     group=groups.get_expert_parallel_group())
+                                     group=groups.get_expert_parallel_group(),
+                                     use_tutel=use_tutel)

     def forward(self, hidden_states, used_token=None):
         """ MoE forward
...
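Note: a minimal usage sketch of the flags this commit adds. The leading constructor arguments (hidden_size, expert, num_experts) are assumptions based on the surrounding class and may differ across DeepSpeed versions; only drop_tokens, use_rts, and use_tutel come from the hunks above.

    import torch
    from deepspeed.moe.layer import MoE

    # Hypothetical sizes; any nn.Module can serve as the expert.
    moe = MoE(hidden_size=512,
              expert=torch.nn.Linear(512, 512),
              num_experts=8,
              k=1,
              drop_tokens=True,
              use_rts=True,     # Random Token Selection (on by default)
              use_tutel=True)   # requested; falls back with a warning if Tutel is absent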
@@ -13,6 +13,7 @@ Copyright 2021 The Microsoft DeepSpeed Team
 # LICENSE file in the root directory of this source tree.

 from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer
+from deepspeed.utils import logger, log_dist
 from typing import Callable, Dict, TYPE_CHECKING, Any, Optional, Tuple, Union, cast

 import time
@@ -31,6 +32,16 @@ uniform_map: Dict[torch.device, Callable] = {}
 gumbel_map: Dict[torch.device, Callable] = {}
 exp_selection_uniform_map: Dict[torch.device, Callable] = {}

+try:
+    # To enable Tutel MoE optimizations:
+    #   python3 -m pip install --user --upgrade git+https://github.com/microsoft/tutel@v0.1.x
+    from tutel import moe as tutel_moe
+    TUTEL_INSTALLED = True
+except ImportError:
+    # Fail silently so we don't spam logs unnecessarily if user isn't using tutel
+    TUTEL_INSTALLED = False
+    pass
+

 def multiplicative_jitter(x, device: torch.device, epsilon=1e-2):
     """
@@ -137,9 +148,11 @@ def top1gating(logits: torch.Tensor,
                min_capacity: int,
                used_token: torch.Tensor = None,
                noisy_gate_policy: Optional[str] = None,
-               drop_tokens: bool = True) -> Tuple[Tensor,
-                                                  Tensor,
-                                                  Tensor]:
+               drop_tokens: bool = True,
+               use_rts: bool = True,
+               use_tutel: bool = False) -> Tuple[Tensor,
+                                                 Tensor,
+                                                 Tensor]:
     """Implements Top1Gating on logits."""
     if noisy_gate_policy == 'RSample':
         logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
@@ -179,31 +192,48 @@ def top1gating(logits: torch.Tensor,
     ce = torch.mean(mask1.float(), dim=0)
     l_aux = torch.sum(me * ce) * num_experts

-    uniform = exp_selection_uniform_map.get(logits.device)
-    if uniform is None:
-        uniform = torch.distributions.uniform.Uniform(
-            low=torch.tensor(0.0,
-                             device=logits.device),
-            high=torch.tensor(1.0,
-                              device=logits.device)).rsample
-        exp_selection_uniform_map[logits.device] = uniform
-
-    mask1_rand = mask1 * uniform(mask1.shape)
+    # Random Token Selection
+    if use_rts:
+        uniform = exp_selection_uniform_map.get(logits.device)
+        if uniform is None:
+            uniform = torch.distributions.uniform.Uniform(
+                low=torch.tensor(0.0,
+                                 device=logits.device),
+                high=torch.tensor(1.0,
+                                  device=logits.device)).rsample
+            exp_selection_uniform_map[logits.device] = uniform
+
+        mask1_rand = mask1 * uniform(mask1.shape)

     assert logits.shape[0] >= min_capacity, "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size."

     _, top_idx = torch.topk(mask1_rand, k=capacity, dim=0)

     new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
+    mask1 = new_mask1
+
+    if use_tutel:
+        # Tutel doesn't support index values masked with zero
+        # so we need to replace masked indices with -1
+        indices_mask = mask1.sum(dim=1) * num_experts - 1
+        indices1_s = torch.min(indices1_s, indices_mask)

     # Compute locations in capacity buffer
-    locations1 = torch.cumsum(new_mask1, dim=0) - 1
+    if use_tutel:
+        locations1 = tutel_moe.fast_cumsum_sub_one(mask1)
+    else:
+        locations1 = torch.cumsum(mask1, dim=0) - 1
+
+    if use_tutel:
+        gates1_s = (gates * mask1).sum(dim=1)
+        locations1_s = torch.sum(locations1 * mask1, dim=1)
+        return l_aux, capacity, num_experts, [indices1_s,], [locations1_s,], [gates1_s,], exp_counts

     # Store the capacity location for each token
-    locations1_s = torch.sum(locations1 * new_mask1, dim=1)
+    locations1_s = torch.sum(locations1 * mask1, dim=1)

     # Normalize gate probabilities
-    mask1_float = new_mask1.float()
+    mask1_float = mask1.float()
     gates = gates * mask1_float

     locations1_sc = F.one_hot(locations1_s, num_classes=capacity).float()
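Note: a tiny worked example of the Tutel-specific index masking and the capacity-location computation above, assuming 4 tokens and 2 experts; torch.cumsum(...) - 1 stands in for the fast_cumsum_sub_one kernel, which computes the same thing.

    import torch

    num_experts = 2
    # mask1[s, e] = 1 if token s was kept for expert e after the capacity cut.
    mask1 = torch.tensor([[1, 0],
                          [0, 1],
                          [0, 0],   # token 2 was dropped
                          [1, 0]])
    indices1_s = torch.tensor([0, 1, 1, 0])  # top-1 expert per token

    # Kept rows sum to 1, dropped rows to 0, so indices_mask is
    # num_experts - 1 for kept tokens and -1 for dropped ones;
    # torch.min then rewrites dropped tokens' expert indices to -1.
    indices_mask = mask1.sum(dim=1) * num_experts - 1    # [1, 1, -1, 1]
    indices1_s = torch.min(indices1_s, indices_mask)     # [0, 1, -1, 0]

    # Running per-expert count minus one = slot in the capacity buffer.
    locations1 = torch.cumsum(mask1, dim=0) - 1
    locations1_s = torch.sum(locations1 * mask1, dim=1)  # [0, 0, 0, 1]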
@@ -314,7 +344,8 @@ class TopKGate(torch.nn.Module):
                  eval_capacity_factor: float = 1.0,
                  min_capacity: int = 4,
                  noisy_gate_policy: Optional[str] = None,
-                 drop_tokens: bool = True) -> None:
+                 drop_tokens: bool = True,
+                 use_rts: bool = True) -> None:
         super().__init__()

         # Only top-1 and top-2 are supported at the moment.
@@ -330,14 +361,15 @@ class TopKGate(torch.nn.Module):
         self.wall_clock_breakdown = False
         self.gate_time = 0.0
         self.drop_tokens = drop_tokens
+        self.use_rts = use_rts

     def forward(
             self,
             input: torch.Tensor,
-            used_token: torch.Tensor = None
-    ) -> Tuple[Tensor,
-               Tensor,
-               Tensor]:  # type: ignore
+            used_token: torch.Tensor = None,
+            use_tutel: bool = False) -> Tuple[Tensor,
+                                              Tensor,
+                                              Tensor]:  # type: ignore

         if self.wall_clock_breakdown:
             self.timers('TopKGate').start()
@@ -357,7 +389,9 @@ class TopKGate(torch.nn.Module):
                 self.min_capacity,
                 used_token,
                 self.noisy_gate_policy if self.training else None,
-                self.drop_tokens)
+                self.drop_tokens,
+                self.use_rts,
+                use_tutel)
         else:
             gate_output = top2gating(
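Note: with this change the gate has two return conventions, selected by the use_tutel argument threaded through above; a sketch of how callers unpack gate_output (the S/E/C shape names follow the MOELayer.forward hunk below):

    if not use_tutel:
        # Standard path: dense [S, E, C] combine/dispatch tensors.
        l_aux, combine_weights, dispatch_mask, exp_counts = gate_output
    else:
        # Tutel path: capacity C, expert count E, and per-token
        # index/location/gate lists for Tutel's fast dispatcher.
        l_aux, C, E, indices_, locations_, gates_, exp_counts = gate_output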
@@ -392,7 +426,8 @@ class MOELayer(Base):
                  gate: Module,
                  experts: Module,
                  num_local_experts: int,
-                 group: Optional[Any] = None) -> None:
+                 group: Optional[Any] = None,
+                 use_tutel: bool = False) -> None:
         super().__init__()
         self.gate = gate
         self.experts = experts
@@ -405,6 +440,14 @@ class MOELayer(Base):
         self.timers = SynchronizedWallClockTimer()
         self.wall_clock_breakdown = False

+        self.use_tutel = use_tutel and TUTEL_INSTALLED
+
+        if self.use_tutel:
+            logger.info('Using Tutel optimizations.')
+        elif use_tutel and not TUTEL_INSTALLED:
+            logger.warning("Tutel optimization requested but not installed. "
+                           "Proceeding without Tutel.")
+
     def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
         if self.wall_clock_breakdown:
@@ -418,11 +461,23 @@
         # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
         reshaped_input = input[0].reshape(-1, d_model)

-        self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1])
-
-        dispatched_input = einsum("sec,sm->ecm",
-                                  dispatch_mask.type_as(input[0]),
-                                  reshaped_input)
+        if self.use_tutel:
+            self.l_aux, C, E, indices_, locations_, gates_, self.exp_counts = self.gate(reshaped_input, input[1], True)
+            S, M = reshaped_input.size(0), reshaped_input.size(1)
+
+            if not hasattr(self, '_tutel_dispatcher'):
+                self._tutel_dispatcher = tutel_moe.fast_dispatcher(
+                    E,
+                    C,
+                    M,
+                    dispatch_dtype=reshaped_input.dtype)
+            self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C)
+            dispatched_input = self._tutel_dispatcher.encode(reshaped_input)
+        else:
+            self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1])
+            dispatched_input = einsum("sec,sm->ecm",
+                                      dispatch_mask.type_as(input[0]),
+                                      reshaped_input)

         if self.wall_clock_breakdown:
             self.timers('falltoall').start()
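Note: the dense dispatch in the else-branch is a masked gather into per-expert capacity slots; a self-contained sketch with hypothetical sizes (S tokens, E experts, C capacity slots, M model dim):

    import torch

    S, E, C, M = 8, 2, 4, 16                 # hypothetical sizes
    dispatch_mask = torch.zeros(S, E, C)     # one-hot slot per kept token
    dispatch_mask[0, 1, 0] = 1.0             # e.g. token 0 -> expert 1, slot 0
    reshaped_input = torch.randn(S, M)

    # "sec,sm->ecm": route each kept token into its (expert, slot) cell,
    # producing the [E, C, M] buffer that feeds the all-to-all.
    dispatched = torch.einsum("sec,sm->ecm", dispatch_mask, reshaped_input)
    assert dispatched.shape == (E, C, M)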
@@ -455,9 +510,12 @@ class MOELayer(Base):
                                                -1,
                                                d_model)

-        combined_output = einsum("sec,ecm->sm",
-                                 combine_weights.type_as(input[0]),
-                                 expert_output)
+        if self.use_tutel:
+            combined_output = self._tutel_dispatcher.decode(expert_output.view(E * C, M))
+        else:
+            combined_output = einsum("sec,ecm->sm",
+                                     combine_weights.type_as(input[0]),
+                                     expert_output)

         a = combined_output.reshape(input[0].shape)
...
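Note: the combine step is the transpose of the dispatch: gate weights of shape [S, E, C] contract the [E, C, M] expert outputs back to [S, M], which is what the Tutel decode call reproduces from its flat [E*C, M] view. A sketch under the same hypothetical sizes as above:

    import torch

    S, E, C, M = 8, 2, 4, 16
    combine_weights = torch.rand(S, E, C)    # gate probability per slot
    expert_output = torch.randn(E, C, M)

    # "sec,ecm->sm": gate-weighted sum over the slots a token occupies.
    combined = torch.einsum("sec,ecm->sm", combine_weights, expert_output)
    assert combined.shape == (S, M)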