Switch Transformer

This is a miniature implementation of the paper Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity. Our implementation has only a few million parameters and doesn’t do model-parallel distributed training. It trains on a single GPU, but we implement the concept of switching as described in the paper.

The Switch Transformer uses different parameters for each token by switching among parameters based on the token. Only a fraction of the parameters is used for each token, so the model can have many more parameters without a proportional increase in the computational cost per token.

The switching happens at the Position-wise Feedforward network (FFN) of each transformer block. A position-wise feedforward network consists of two sequential fully connected layers. In a switch transformer we have multiple FFNs (multiple experts), and we choose which one to use based on a router. The router outputs a set of probabilities for picking an FFN; we pick the one with the highest probability and evaluate only that one. So the computational cost is essentially the same as having a single FFN. In our implementation this doesn’t parallelize well when you have many or large FFNs, since it all happens on a single GPU. In a distributed setup you would place each FFN (each very large) on a different device.
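Here is a rough parameter count in plain Python that makes the parameter-versus-compute trade-off concrete. The sketch assumes each expert is a two-layer FFN with biases (d_model -> d_ff -> d_model) and that the router is one linear layer with a bias. The sizes 512, 2048 and 4 are illustrative, and the helper names `ffn_params` and `switch_ffn_params` are hypothetical.

```python
from typing import Tuple


def ffn_params(d_model: int, d_ff: int) -> int:
    """Parameters of one expert: two weight matrices plus two bias vectors (assumed)."""
    return d_model * d_ff + d_ff + d_ff * d_model + d_model


def switch_ffn_params(d_model: int, d_ff: int, n_experts: int) -> Tuple[int, int]:
    """Return (parameters stored, parameters used per token) for one switch FFN layer."""
    router = d_model * n_experts + n_experts  # linear layer d_model -> n_experts
    total = n_experts * ffn_params(d_model, d_ff) + router
    # Top-1 routing: each token passes through the router and exactly one expert.
    active = ffn_params(d_model, d_ff) + router
    return total, active


total, active = switch_ffn_params(d_model=512, d_ff=2048, n_experts=4)
print(total, active)  # 8400900 2101764
```

The number of stored parameters grows linearly with n_experts. The parameters touched per token stay at one FFN plus the small router.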

The paper introduces another loss term to balance load among the experts (FFNs) and discusses dropping tokens when routing is not balanced.

Here’s the training code and a notebook for training a switch transformer on the Tiny Shakespeare dataset.


import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.transformers.mha import MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.utils import clone_module_list

Routing among multiple FFNs

class SwitchFeedForward(Module):
  • capacity_factor is the capacity of each expert as a factor relative to ideally balanced load
  • drop_tokens specifies whether to drop tokens if more tokens are routed to an expert than the capacity
  • is_scale_prob specifies whether to multiply the input to the FFN by the routing probability
  • n_experts is the number of experts
  • expert is the expert layer, an FFN module
  • d_model is the number of features in a token embedding
  • d_ff (the number of features in the hidden layer of the FFN) and dropout (the dropout probability in the FFN) are not arguments of this module; they are set on the expert module, which is copied n_experts times
    def __init__(self, *,
                 capacity_factor: float,
                 drop_tokens: bool,
                 is_scale_prob: bool,
                 n_experts: int,
                 expert: FeedForward,
                 d_model: int):
        super().__init__()

        self.capacity_factor = capacity_factor
        self.is_scale_prob = is_scale_prob
        self.n_experts = n_experts
        self.drop_tokens = drop_tokens

Make copies of the FFNs

        self.experts = clone_module_list(expert, n_experts)

Routing layer and softmax

        self.switch = nn.Linear(d_model, n_experts)
        self.softmax = nn.Softmax(dim=-1)
  • x is the input to the switching module with shape [seq_len, batch_size, d_model]
    def __call__(self, x: torch.Tensor):

Capture the shape to change shapes later

        seq_len, batch_size, d_model = x.shape

Flatten the sequence and batch dimensions

        x = x.view(-1, d_model)

Get the routing probabilities for each of the tokens,

$$p_i(x) = \frac{e^{h(x)_i}}{\sum_{j=1}^{N} e^{h(x)_j}}$$

where $N$ is the number of experts n_experts and $h(\cdot)$ is the linear transformation of token embeddings.

        route_prob = self.softmax(self.switch(x))

Get the maximum routing probabilities and the routes. We route to the expert with the highest probability.

        route_prob_max, routes = torch.max(route_prob, dim=-1)
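Below is a minimal plain-Python sketch of these two steps for a single token. It assumes the router logits $h(x)$ are already computed, and `route_top1` is a hypothetical helper, not part of labml_nn. It subtracts the largest logit before exponentiating. That leaves $p_i(x)$ unchanged but avoids overflow.

```python
import math
from typing import List, Tuple


def route_top1(logits: List[float]) -> Tuple[List[float], int, float]:
    """Softmax over router logits h(x), then pick the most probable expert.

    Returns (probabilities p_i(x), chosen expert index, its probability).
    """
    # Shift by the max logit for numerical stability; the softmax is unchanged.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    z = sum(exps)
    probs = [e / z for e in exps]
    # Top-1 (switch) routing: argmax over experts.
    route = max(range(len(probs)), key=lambda i: probs[i])
    return probs, route, probs[route]


probs, route, p_max = route_top1([0.5, 2.0, -1.0, 0.0])
print(route, round(sum(probs), 6))  # 1 1.0
```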

Scale the inputs to the experts by the routing probabilities

        if self.is_scale_prob:
            factor = route_prob_max

Don’t scale the values, but multiply by $\frac{p}{\hat{p}} = 1$, where $\hat{p}$ is $p$ detached from the computation graph, so that the gradients still flow to the router

        else:
            factor = route_prob_max / route_prob_max.detach()
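The following short derivation shows why this ratio keeps the router trainable. Here $p$ is the token's maximum routing probability route_prob_max and $\hat{p}$ is its detached copy, so autograd treats $\hat{p}$ as a constant. $\mathcal{L}$ is the training loss and $g = \partial \mathcal{L} / \partial x'$ is the gradient arriving at the scaled input. This notation is introduced only for the sketch.

```latex
x' = x \cdot \frac{p}{\hat{p}}, \qquad \hat{p} = p \text{ numerically} \;\Rightarrow\; x' = x \text{ in the forward pass}

\frac{\partial x'}{\partial p} = \frac{x}{\hat{p}} = \frac{x}{p}
\qquad\Rightarrow\qquad
\frac{\partial \mathcal{L}}{\partial p} = \Big\langle g, \frac{\partial x'}{\partial p} \Big\rangle = \frac{\langle g, x \rangle}{p}
```

The forward value is unchanged, but $p$ still receives a gradient, and through the softmax so do the router weights. Without the factor, the only link to the router along this path would be the argmax index routes, which is not differentiable.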

Multiply by the scaling factor

        x = x * factor.view(-1, 1)

Get indexes of tokens going to each expert

        indexes_list = [torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(self.n_experts)]
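The per-expert index lists have a plain-Python equivalent. `indexes_per_expert` is a hypothetical helper that mirrors what the list comprehension over `torch.eq(...).nonzero(...)` computes. The example routes six flattened tokens over four experts.

```python
from typing import List


def indexes_per_expert(routes: List[int], n_experts: int) -> List[List[int]]:
    """Group flattened token positions by the expert they were routed to."""
    buckets: List[List[int]] = [[] for _ in range(n_experts)]
    for token, expert in enumerate(routes):
        buckets[expert].append(token)
    return buckets


print(indexes_per_expert([2, 0, 2, 1, 2, 0], n_experts=4))
# [[1, 5], [3], [0, 2, 4], []]
```

Expert 3 receives no tokens in this example, which is allowed: its list is simply empty. The list lengths [2, 1, 3, 0] are what counts records further down.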

Initialize an empty tensor to store outputs

        final_output = x.new_zeros(x.shape)

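Capacity of each expert, truncated to an integer,

$$\mathrm{expert\;capacity} = \frac{\mathrm{tokens\;per\;batch}}{\mathrm{number\;of\;experts}} \times \mathrm{capacity\;factor}$$

        capacity = int(self.capacity_factor * len(x) / self.n_experts)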

Number of tokens routed to each expert.

        counts = x.new_tensor([len(indexes_list[i]) for i in range(self.n_experts)])

Initialize an empty list of dropped tokens

        dropped = []

Only drop tokens if drop_tokens is True.

        if self.drop_tokens:

Drop tokens in each of the experts

            for i in range(self.n_experts):

Ignore if the expert is not over capacity

                if len(indexes_list[i]) <= capacity:
                    continue

Shuffle indexes before dropping

                indexes_list[i] = indexes_list[i][torch.randperm(len(indexes_list[i]))]

Collect the tokens over capacity as dropped tokens

                dropped.append(indexes_list[i][capacity:])

Keep only the tokens up to the capacity of the expert

                indexes_list[i] = indexes_list[i][:capacity]
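The capacity rule and the drop step can be sketched in plain Python, assuming the per-expert index lists from the routing step are given as Python lists. `apply_capacity` is a hypothetical helper, and `random.Random(seed)` stands in for `torch.randperm` so that the example is reproducible.

```python
import random
from typing import List, Tuple


def apply_capacity(indexes_list: List[List[int]], n_tokens: int,
                   capacity_factor: float, seed: int = 0
                   ) -> Tuple[List[List[int]], List[int], int]:
    """Keep at most `capacity` tokens per expert; return (kept, dropped, capacity)."""
    n_experts = len(indexes_list)
    # tokens per batch / number of experts * capacity factor, truncated by int()
    capacity = int(capacity_factor * n_tokens / n_experts)
    rng = random.Random(seed)
    kept: List[List[int]] = []
    dropped: List[int] = []
    for idx in indexes_list:
        if len(idx) <= capacity:
            kept.append(list(idx))
            continue
        # Shuffle so that which tokens overflow is random, not position-dependent.
        idx = list(idx)
        rng.shuffle(idx)
        dropped.extend(idx[capacity:])
        kept.append(idx[:capacity])
    return kept, dropped, capacity


kept, dropped, capacity = apply_capacity([[1, 5], [3], [0, 2, 4], []],
                                         n_tokens=6, capacity_factor=1.0)
print(capacity, [len(k) for k in kept], len(dropped))  # 1 [1, 1, 1, 0] 3
```

With 6 tokens, 4 experts and capacity_factor = 1.0, the capacity is int(1.5) = 1, so three of the six tokens are dropped. With capacity_factor = 2.0 the capacity is 3 and nothing is dropped. This is the trade-off the capacity factor controls: a larger factor drops fewer tokens, at the cost (in the paper's distributed setting) of larger fixed-size expert buffers.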

Get outputs of the expert FFNs

        route_outputs = [self.experts[i](x[indexes_list[i], :]) for i in range(self.n_experts)]

Assign to final output

        for i in range(self.n_experts):
            final_output[indexes_list[i], :] = route_outputs[i]

Pass through the dropped tokens

        if dropped:
            dropped = torch.cat(dropped)
            final_output[dropped, :] = x[dropped, :]
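The scatter step looks like this in plain Python, assuming tokens are lists of floats and each expert is any callable on one token. `combine_outputs` and the two lambda experts are hypothetical. The module calls each expert once on all of its tokens; this sketch instead loops token by token for readability.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def combine_outputs(x: Sequence[Vector], kept: List[List[int]], dropped: List[int],
                    experts: List[Callable[[Vector], Vector]]) -> List[Vector]:
    """Scatter expert outputs back to token positions; dropped tokens pass through."""
    out: List[Vector] = [[0.0] * len(v) for v in x]  # like x.new_zeros(x.shape)
    for i, idx in enumerate(kept):
        for t in idx:
            out[t] = experts[i](x[t])
    for t in dropped:
        # Dropped tokens skip every expert; their input is copied unchanged.
        out[t] = list(x[t])
    return out


x = [[1.0], [2.0], [3.0]]
experts = [lambda v: [10 * a for a in v], lambda v: [-a for a in v]]
print(combine_outputs(x, kept=[[0], [2]], dropped=[1], experts=experts))
# [[10.0], [2.0], [-3.0]]
```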

Change the shape of the final output back to [seq_len, batch_size, d_model]

        final_output = final_output.view(seq_len, batch_size, d_model)

Return

  • the final output
  • the number of tokens routed to each expert
  • the sum of routing probabilities for each expert
  • the number of tokens dropped

These are used for the load balancing loss and logging.

        return final_output, counts, route_prob.sum(0), len(dropped)

Switch Transformer Block

This is the same as a normal transformer block, except that it also handles the extra outputs of the switching feed-forward module.

class SwitchTransformerLayer(Module):
  • d_model is the token embedding size
  • attn is the attention module
  • feed_forward is the feed forward module (which is the switching module in this case)
  • dropout_prob is the probability of dropping out after self attention and FFN
    def __init__(self, *,
                 d_model: int,
                 attn: MultiHeadAttention,
                 feed_forward: SwitchFeedForward,
                 dropout_prob: float):
        super().__init__()
        self.size = d_model
        self.attn = attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

    def __call__(self, *,
                 x: torch.Tensor,
                 mask: torch.Tensor):

Normalize the vectors before doing self attention

        z = self.norm_self_attn(x)

Run through self attention, i.e. keys and values are from self

        self_attn = self.attn(query=z, key=z, value=z, mask=mask)

Add the self attention results

        x = x + self.dropout(self_attn)

Normalize for feed-forward

        z = self.norm_ff(x)

Pass through the switching feed-forward network

        ff, counts, route_prob, n_dropped = self.feed_forward(z)

Add the feed-forward results back

        x = x + self.dropout(ff)

        return x, counts, route_prob, n_dropped
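This plain-Python sketch shows the block's pre-norm residual pattern for a single token, and how the router statistics pass through untouched. It makes simplifying assumptions. The LayerNorm has no learnable scale or shift, and dropout is omitted (as in evaluation mode). `attn` and `switch_ffn` are hypothetical per-token stand-ins; real self-attention mixes information across positions, which a per-token function cannot show.

```python
import math
from typing import Callable, Dict, List, Tuple

Vector = List[float]


def layer_norm(v: Vector, eps: float = 1e-5) -> Vector:
    """Normalize one token's features to zero mean and unit variance (no affine params)."""
    mean = sum(v) / len(v)
    var = sum((a - mean) ** 2 for a in v) / len(v)
    return [(a - mean) / math.sqrt(var + eps) for a in v]


def pre_norm_block(x: Vector,
                   attn: Callable[[Vector], Vector],
                   switch_ffn: Callable[[Vector], Tuple[Vector, Dict[str, int]]]
                   ) -> Tuple[Vector, Dict[str, int]]:
    # Sub-layer 1: x + attn(LN(x))
    x = [a + b for a, b in zip(x, attn(layer_norm(x)))]
    # Sub-layer 2: x + ffn(LN(x)); the routing statistics are returned alongside
    ff, stats = switch_ffn(layer_norm(x))
    x = [a + b for a, b in zip(x, ff)]
    return x, stats


# Hypothetical stand-ins: attention adds nothing, the "switch FFN" doubles its
# input and reports which expert it used.
y, stats = pre_norm_block([1.0, 2.0, 3.0],
                          attn=lambda z: [0.0] * len(z),
                          switch_ffn=lambda z: ([2 * a for a in z], {"expert": 0}))
print(stats)  # {'expert': 0}
```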

Switch Transformer

class SwitchTransformer(Module):
    def __init__(self, layer: SwitchTransformerLayer, n_layers: int):
        super().__init__()

Make copies of the transformer layer

        self.layers = clone_module_list(layer, n_layers)

Final normalization layer

        self.norm = nn.LayerNorm([layer.size])

    def __call__(self, x: torch.Tensor, mask: torch.Tensor):

Run through each transformer layer

        counts, route_prob, n_dropped = [], [], []
        for layer in self.layers:
            x, f, p, n_d = layer(x=x, mask=mask)
            counts.append(f)
            route_prob.append(p)
            n_dropped.append(n_d)

Finally, normalize the vectors

        x = self.norm(x)

        return x, torch.stack(counts), torch.stack(route_prob), n_dropped
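The stacked counts and route_prob sums are what the load balancing loss needs. Below is a plain-Python sketch of the paper's auxiliary loss, $\alpha N \sum_{i=1}^{N} f_i P_i$. Here $f_i$ is the fraction of tokens routed to expert $i$ and $P_i$ is the mean routing probability for expert $i$. Three choices are assumptions of this sketch and not necessarily what the training code does: summing the per-layer losses, the default $\alpha = 0.01$, and the helper name `load_balancing_loss`.

```python
from typing import Sequence


def load_balancing_loss(counts: Sequence[Sequence[float]],
                        route_prob_sums: Sequence[Sequence[float]],
                        alpha: float = 0.01) -> float:
    """alpha * N * sum_i f_i * P_i for each layer, summed over layers (assumed)."""
    total = 0.0
    for layer_counts, layer_probs in zip(counts, route_prob_sums):
        n_experts = len(layer_counts)
        # counts are taken before dropping, so they sum to the number of tokens
        n_tokens = sum(layer_counts)
        f = [c / n_tokens for c in layer_counts]   # fraction of tokens per expert
        p = [s / n_tokens for s in layer_probs]    # mean routing probability per expert
        total += alpha * n_experts * sum(fi * pi for fi, pi in zip(f, p))
    return total


# Two layers, 8 tokens, 4 experts: one balanced layer, one fully collapsed layer.
counts = [[2, 2, 2, 2], [8, 0, 0, 0]]
route_prob_sums = [[2.0, 2.0, 2.0, 2.0], [8.0, 0.0, 0.0, 0.0]]
print(load_balancing_loss(counts, route_prob_sums, alpha=1.0))  # 5.0
```

For uniform routing the sketch gives exactly $\alpha$ per layer. If one expert takes every token with probability 1, it gives $\alpha N$. In a real training loop $P_i$ must keep its gradient, since it comes from route_prob. $f_i$ comes from non-differentiable argmax counts.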