Unverified commit 07887f66, authored by Gary Miguel, committed by GitHub

sharded_moe: make top1gating ONNX-exportable (#1578)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent 64c2946a
@@ -22,6 +22,7 @@ import torch
 from torch import Tensor
 import torch.distributed as dist
 from torch.nn import Module, ModuleList
+import torch.nn.functional as F
 
 if TYPE_CHECKING:
     Base = Module[Tensor]
@@ -79,12 +80,6 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
     return gumbel(shape)
 
-import torch.distributed as dist
-
-# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
-# See https://arxiv.org/pdf/2006.16668.pdf for details.
 
 # Based on https://github.com/pytorch/pytorch/pull/40762
 class _AllToAll(torch.autograd.Function):
     @staticmethod
@@ -102,16 +97,13 @@ class _AllToAll(torch.autograd.Function):
         return (None, _AllToAll.apply(ctx.group, *grad_output))
 
-from torch import nn
-import torch.nn.functional as F
-
-import math
 
 # einsum rewrites are on par or more performant
 # switch can be bubbled up in future
 USE_EINSUM = True
 
 # einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
 # See https://arxiv.org/pdf/2006.16668.pdf for details.
 def einsum(rule, a, b):
     if USE_EINSUM:
         return torch.einsum(rule, a, b)
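For illustration (not part of this diff): when `USE_EINSUM` is switched off, the wrapper substitutes hand-written broadcasting rewrites for specific rules. A minimal sketch of the equivalence for the `"s,se->se"` rule used later in this file, with made-up tensors:

```python
import torch

a = torch.randn(5)       # per-token scalar gate, shape (s,)
b = torch.randn(5, 3)    # token-by-expert tensor, shape (s, e)

ref = torch.einsum("s,se->se", a, b)  # scale each row of b by the matching entry of a
alt = a.unsqueeze(1) * b              # equivalent broadcasted rewrite
assert torch.allclose(ref, alt)
```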
@@ -143,14 +135,47 @@ def einsum(rule, a, b):
         return torch.einsum(rule, a, b)
 
 
-def top1gating(logits: torch.Tensor,
+# The following functions are extracted and scripted
+# because otherwise during a torch.jit.trace, the non-Tensor
+# values used in the calculations get recorded as constants.
+# torch.jit.script coerces them into Tensors and preserves
+# their dynamic shapes. This enables ONNX export.
+# We can't script the entire top1gating function because it
+# includes stateful caching logic which is incompatible with ONNX.
+@torch.jit.script
+def _capacity(gates: Tensor, capacity_factor: Tensor, min_capacity: Tensor) -> Tensor:
+    # gates has shape of SE
+    num_tokens = gates.shape[0]
+    num_experts = gates.shape[1]
+    # to(torch.int64) works around a bug in torch.onnx.export:
+    # it should cast k to int64 when converting torch.topk but it doesn't.
+    capacity = torch.ceil((num_tokens / num_experts) * capacity_factor).to(torch.int64)
+    if capacity < min_capacity:
+        capacity = min_capacity.to(torch.int64)
+    return capacity
+
+
+@torch.jit.script
+def _top_idx(source, k):
+    return torch.topk(source, k=k, dim=0)[1]
+
+
+@torch.jit.script
+def _one_hot_to_float(x, num_classes):
+    return F.one_hot(x, num_classes=num_classes).float()
+
+
+def top1gating(logits: Tensor,
                capacity_factor: float,
                min_capacity: int,
-               used_token: torch.Tensor = None,
+               used_token: Tensor = None,
                noisy_gate_policy: Optional[str] = None,
                drop_tokens: bool = True,
               use_rts: bool = True,
               use_tutel: bool = False) -> Tuple[Tensor,
                                                 Tensor,
                                                 Tensor,
                                                 Tensor]:
     """Implements Top1Gating on logits."""
@@ -159,19 +184,16 @@ def top1gating(logits: torch.Tensor,
     # everything is in fp32 in this function
     gates = F.softmax(logits, dim=1)
 
-    # gates has shape of SE
-    num_tokens = int(gates.shape[0])
-    num_experts = int(gates.shape[1])
-    # round-up
-    capacity = math.ceil((num_tokens / num_experts) * capacity_factor)
-    if capacity < min_capacity:
-        capacity = min_capacity
+    capacity = _capacity(gates,
+                         torch.tensor(capacity_factor),
+                         torch.tensor(min_capacity))
 
     # Create a mask for 1st's expert per token
     # noisy gating
     indices1_s = torch.argmax(
         logits_w_noise if noisy_gate_policy == 'RSample' else gates,
         dim=1)
     num_experts = int(gates.shape[1])
     mask1 = F.one_hot(indices1_s, num_classes=num_experts)
 
     # mask only used tokens
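As a quick sanity check (illustrative numbers, not from the commit), the scripted call reproduces the arithmetic of the removed `math.ceil` path. With the helpers above in scope:

```python
import math
import torch

gates = torch.randn(10, 4).softmax(dim=1)  # 10 tokens, 4 experts
old = max(math.ceil((10 / 4) * 1.25), 2)   # the removed eager computation
new = _capacity(gates, torch.tensor(1.25), torch.tensor(2))
assert old == int(new) == 4                # ceil(2.5 * 1.25) = 4 > min_capacity
```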
@@ -207,7 +229,7 @@ def top1gating(logits: torch.Tensor,
     assert logits.shape[0] >= min_capacity, "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size."
 
-    _, top_idx = torch.topk(mask1_rand, k=capacity, dim=0)
+    top_idx = _top_idx(mask1_rand, capacity)
 
     new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
     mask1 = new_mask1
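The replacement is behavior-preserving: `torch.topk(...)[1]` along dim 0 returns, for each expert column, the row indices of the k highest-scoring tokens, which the `scatter_` above turns back into a mask. A small check with made-up values, passing k as a 0-dim tensor the way the call site passes `capacity`:

```python
import torch

src = torch.tensor([[0.9, 0.1],
                    [0.2, 0.8],
                    [0.7, 0.3]])  # (tokens, experts)
idx = _top_idx(src, torch.tensor(2))
print(idx)  # tensor([[0, 1],
            #         [2, 2]])  column j: rows of the 2 largest entries of expert j
```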
@@ -236,7 +258,7 @@ def top1gating(logits: torch.Tensor,
     mask1_float = mask1.float()
     gates = gates * mask1_float
 
-    locations1_sc = F.one_hot(locations1_s, num_classes=capacity).float()
+    locations1_sc = _one_hot_to_float(locations1_s, capacity)
     combine_weights = einsum("se,sc->sec", gates, locations1_sc)
 
     dispatch_mask = combine_weights.bool()
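Likewise, `_one_hot_to_float` only folds the `.float()` cast into a scriptable helper, so the capacity-sized one-hot matrix keeps a dynamic second dimension at export. With made-up locations:

```python
import torch

locations1_s = torch.tensor([0, 2, 1])  # each token's slot within its expert
out = _one_hot_to_float(locations1_s, torch.tensor(4))
print(out)
# tensor([[1., 0., 0., 0.],
#         [0., 0., 1., 0.],
#         [0., 1., 0., 0.]])
```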
@@ -244,24 +266,22 @@
     return l_aux, combine_weights, dispatch_mask, exp_counts
 
 
-def top2gating(logits: torch.Tensor,
+def top2gating(logits: Tensor,
                capacity_factor: float) -> Tuple[Tensor,
                                                 Tensor,
                                                 Tensor,
                                                 Tensor]:
     """Implements Top2Gating on logits."""
     # everything is in fp32 in this function
     # logits_fp32 = logits.to(torch.float32)
     gates = F.softmax(logits, dim=1)
 
-    # gates has shape of SE
-    num_tokens = int(gates.shape[0])
-    num_experts = int(gates.shape[1])
-    # capacity = (2 * num_tokens // num_experts) * capacity_factor
-    # round-up
-    capacity = math.ceil((2 * num_tokens / num_experts) * capacity_factor)
+    capacity = _capacity(gates,
+                         torch.tensor(capacity_factor * 2),
+                         torch.tensor(min_capacity))
 
     # Create a mask for 1st's expert per token
     indices1_s = torch.argmax(gates, dim=1)
     num_experts = int(gates.shape[1])
     mask1 = F.one_hot(indices1_s, num_classes=num_experts)
 
     # Create a mask for 2nd's expert per token using Gumbel-max trick
@@ -308,8 +328,8 @@ def top2gating(logits: torch.Tensor,
     # Calculate combine_weights and dispatch_mask
     gates1 = einsum("s,se->se", gates1_s, mask1_float)
     gates2 = einsum("s,se->se", gates2_s, mask2_float)
-    locations1_sc = F.one_hot(locations1_s, num_classes=capacity).float()
-    locations2_sc = F.one_hot(locations2_s, num_classes=capacity).float()
+    locations1_sc = _one_hot_to_float(locations1_s, capacity)
+    locations2_sc = _one_hot_to_float(locations2_s, capacity)
     combine1_sec = einsum("se,sc->sec", gates1, locations1_sc)
     combine2_sec = einsum("se,sc->sec", gates2, locations2_sc)
     combine_weights = combine1_sec + combine2_sec
@@ -318,7 +338,7 @@
     return l_aux, combine_weights, dispatch_mask, exp_counts
 
 
-class TopKGate(torch.nn.Module):
+class TopKGate(Module):
     """Gate module which implements Top2Gating as described in Gshard_.
     ::