This is a miniature implementation of the paper Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity. Our implementation has only a few million parameters and doesn’t do model-parallel distributed training. It trains on a single GPU, but it implements the concept of switching as described in the paper.
The Switch Transformer uses different parameters for each token by switching among parameters based on the token. Since only a fraction of the parameters is chosen for each token, you can have more parameters without a proportional increase in computational cost.
The switching happens at the position-wise feedforward network (FFN) of each transformer block. A position-wise feedforward network consists of two fully connected layers applied in sequence. In the Switch Transformer we have multiple FFNs (multiple experts), and we choose which one to use based on a router. The router outputs a set of probabilities for picking an FFN; we pick the one with the highest probability and evaluate only that one. So the computational cost per token is essentially the same as having a single FFN. Our implementation doesn’t parallelize well when you have many or large FFNs, since it all happens on a single GPU. In a distributed setup you would place each FFN (each very large) on a different device.
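To make the trade-off concrete, the sketch below counts parameters for a switch FFN under stated assumptions: each expert is two fully connected layers with biases (d_model → d_ff → d_model), the router is a single linear layer with bias, and routing is top-1. The helper names `ffn_params` and `switch_ffn_params` and the sizes (d_model=128, d_ff=256) are illustrative choices, not values from the training code.

```python
from typing import Tuple


def ffn_params(d_model: int, d_ff: int) -> int:
    """Weights and biases of a two-layer FFN: d_model -> d_ff -> d_model (assumed layout)."""
    return (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)


def switch_ffn_params(d_model: int, d_ff: int, n_experts: int) -> Tuple[int, int]:
    """Return (stored, used_per_token) parameter counts for a switch FFN.

    Assumes top-1 routing: each token goes through the router and exactly one expert.
    """
    router = d_model * n_experts + n_experts  # linear layer d_model -> n_experts, with bias
    expert = ffn_params(d_model, d_ff)
    stored = n_experts * expert + router      # every expert's weights are kept in memory
    used_per_token = expert + router          # but a token only touches one expert
    return stored, used_per_token


for n in (1, 4, 16):
    stored, used = switch_ffn_params(d_model=128, d_ff=256, n_experts=n)
    print(f"{n:>2} experts: {stored:>9,} stored, {used:>7,} used per token")
```

Stored parameters grow linearly with the number of experts, while the parameters a single token touches grow only by the router's extra d_model + 1 weights per expert.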
The paper introduces another loss term to balance load among the experts (FFNs) and discusses dropping tokens when routing is not balanced.
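The paper sets each expert's capacity to (tokens per batch / number of experts) × capacity factor, and tokens routed to an expert beyond that capacity are dropped, meaning they skip the expert FFN. The sketch below works through the arithmetic for a made-up routing outcome; `expert_capacity`, `n_dropped`, and the token counts are hypothetical and only for illustration.

```python
from typing import List


def expert_capacity(n_tokens: int, n_experts: int, capacity_factor: float) -> int:
    # Truncate to an integer with int(), as the module's code does.
    return int(capacity_factor * n_tokens / n_experts)


def n_dropped(counts: List[int], capacity: int) -> int:
    # Every token beyond an expert's capacity is dropped.
    return sum(max(c - capacity, 0) for c in counts)


# Hypothetical batch: 64 tokens routed unevenly across 4 experts.
counts = [30, 20, 10, 4]
for cf in (1.0, 1.25, 2.0):
    cap = expert_capacity(n_tokens=sum(counts), n_experts=len(counts), capacity_factor=cf)
    print(f"capacity_factor={cf}: capacity={cap}, dropped={n_dropped(counts, cap)}")
```

A larger capacity factor drops fewer tokens but reserves more expert slots, which in the paper's distributed setting means more buffer memory and communication per expert.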
Here’s the training code and a notebook for training a switch transformer on the Tiny Shakespeare dataset.
import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.transformers.mha import MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.utils import clone_module_list
class SwitchFeedForward(Module):
capacity_factor is the capacity of each expert as a factor relative to ideally balanced load
drop_tokens specifies whether to drop tokens if more tokens are routed to an expert than the capacity
is_scale_prob specifies whether to multiply the input to the FFN by the routing probability
n_experts is the number of experts
expert is the expert layer, an FFN module (its hidden-layer size d_ff and dropout probability are configured on this module)
d_model is the number of features in a token embedding
    def __init__(self, *,
                 capacity_factor: float,
                 drop_tokens: bool,
                 is_scale_prob: bool,
                 n_experts: int,
                 expert: FeedForward,
                 d_model: int):
        super().__init__()

        self.capacity_factor = capacity_factor
        self.is_scale_prob = is_scale_prob
        self.n_experts = n_experts
        self.drop_tokens = drop_tokens
Make copies of the FFNs
        self.experts = clone_module_list(expert, n_experts)
Routing layer and softmax
        self.switch = nn.Linear(d_model, n_experts)
        self.softmax = nn.Softmax(dim=-1)
x is the input to the switching module with shape [seq_len, batch_size, d_model]
    def __call__(self, x: torch.Tensor):
Capture the shape to change shapes later
        seq_len, batch_size, d_model = x.shape
Flatten the sequence and batch dimensions
        x = x.view(-1, d_model)
Get routing probabilities for each of the tokens,
$$p_i(x) = \frac{e^{h(x)_i}}{\sum_{j=1}^{N} e^{h(x)_j}}$$
where $N$ is the number of experts n_experts and $h(\cdot)$ is the linear transformation of token embeddings.
        route_prob = self.softmax(self.switch(x))
Get the maximum routing probabilities and the routes. We route to the expert with the highest probability.
        route_prob_max, routes = torch.max(route_prob, dim=-1)
Scale the inputs to the experts by the routing probabilities
        if self.is_scale_prob:
            factor = route_prob_max
Don’t scale the values but multiply by $\frac{p}{\hat{p}} = 1$, where $\hat{p}$ is $p$ detached from the computation graph, so that the gradients flow
        else:
            factor = route_prob_max / route_prob_max.detach()
Multiply by the scaling factor
        x = x * factor.view(-1, 1)
Get indexes of tokens going to each expert
        indexes_list = [torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(self.n_experts)]
Initialize an empty tensor to store outputs
        final_output = x.new_zeros(x.shape)
Capacity of each expert,
$$\mathrm{expert\ capacity} = \frac{\mathrm{tokens\ per\ batch}}{\mathrm{number\ of\ experts}} \times \mathrm{capacity\ factor}$$
        capacity = int(self.capacity_factor * len(x) / self.n_experts)
Number of tokens routed to each expert.
        counts = x.new_tensor([len(indexes_list[i]) for i in range(self.n_experts)])
Initialize an empty list of dropped tokens
        dropped = []
Only drop tokens if drop_tokens is True.
        if self.drop_tokens:
Drop tokens in each of the experts
            for i in range(self.n_experts):
Ignore if the expert is not over capacity
                if len(indexes_list[i]) <= capacity:
                    continue
Shuffle indexes before dropping
                indexes_list[i] = indexes_list[i][torch.randperm(len(indexes_list[i]))]
Collect the tokens over capacity as dropped tokens
                dropped.append(indexes_list[i][capacity:])
Keep only the tokens up to the capacity of the expert
                indexes_list[i] = indexes_list[i][:capacity]
Get outputs of the expert FFNs
        route_outputs = [self.experts[i](x[indexes_list[i], :]) for i in range(self.n_experts)]
Assign to final output
        for i in range(self.n_experts):
            final_output[indexes_list[i], :] = route_outputs[i]
Pass through the dropped tokens
        if dropped:
            dropped = torch.cat(dropped)
            final_output[dropped, :] = x[dropped, :]
Change the shape of the final output back to [seq_len, batch_size, d_model]
        final_output = final_output.view(seq_len, batch_size, d_model)
Return
* the final output
* number of tokens routed to each expert
* sum of probabilities for each expert
* number of tokens dropped
These are used for the load balancing loss and logging.
        return final_output, counts, route_prob.sum(0), len(dropped)
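The is_scale_prob=False branch relies on a small trick: `p / p.detach()` always evaluates to exactly 1, so the forward values are unchanged, yet autograd still sees a dependence on `p` and sends a gradient of $1/p$ back to the router. The standalone snippet below checks both properties with plain PyTorch on made-up probabilities; the tensor `p` is hypothetical.

```python
import torch

# Hypothetical maximum routing probabilities for three tokens.
p = torch.tensor([0.5, 0.25, 0.8], requires_grad=True)

factor = p / p.detach()  # numerically 1, but still differentiable w.r.t. p
print(factor)            # all ones, with a grad_fn attached

factor.sum().backward()
print(p.grad)            # d(p / c)/dp = 1 / c with c = p.detach(), i.e. 1 / p
```

Note that the paper multiplies the expert's output by the gate value, $y = p_i(x) E_i(x)$, whereas this module applies the factor to the expert's input; with a nonlinear FFN the two are not equivalent.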
This is the same as a normal transformer block, except that it handles the extra outputs of the switch feed-forward module.
class SwitchTransformerLayer(Module):
d_model is the token embedding size
attn is the attention module
feed_forward is the feed-forward module (which is the switching module in this case)
dropout_prob is the probability of dropping out after self attention and FFN
    def __init__(self, *,
                 d_model: int,
                 attn: MultiHeadAttention,
                 feed_forward: SwitchFeedForward,
                 dropout_prob: float):
        super().__init__()
        self.size = d_model
        self.attn = attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])
    def __call__(self, *,
                 x: torch.Tensor,
                 mask: torch.Tensor):
Normalize the vectors before doing self attention
        z = self.norm_self_attn(x)
Run through self attention, i.e. keys and values are from self
        self_attn = self.attn(query=z, key=z, value=z, mask=mask)
Add the self attention results
        x = x + self.dropout(self_attn)
Normalize for feed-forward
        z = self.norm_ff(x)
Pass through the switching feed-forward network
        ff, counts, route_prob, n_dropped = self.feed_forward(z)
Add the feed-forward results back
        x = x + self.dropout(ff)

        return x, counts, route_prob, n_dropped
class SwitchTransformer(Module):
    def __init__(self, layer: SwitchTransformerLayer, n_layers: int):
        super().__init__()
Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
Final normalization layer
        self.norm = nn.LayerNorm([layer.size])
    def __call__(self, x: torch.Tensor, mask: torch.Tensor):
Run through each transformer layer
        counts, route_prob, n_dropped = [], [], []
        for layer in self.layers:
            x, f, p, n_d = layer(x=x, mask=mask)
            counts.append(f)
            route_prob.append(p)
            n_dropped.append(n_d)
Finally, normalize the vectors
        x = self.norm(x)

        return x, torch.stack(counts), torch.stack(route_prob), n_dropped
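The stacked counts and route_prob returned above, each of shape [n_layers, n_experts], are what the paper's auxiliary load-balancing loss needs. With $f_i$ the fraction of tokens dispatched to expert $i$ and $P_i$ the fraction of router probability assigned to it, the loss is $\alpha \cdot N \cdot \sum_{i=1}^{N} f_i P_i$, and the paper uses $\alpha = 10^{-2}$. The sketch below computes $N \sum_i f_i P_i$ per layer and sums over layers; the function name `load_balancing_loss`, the choice to sum rather than average over layers, and the example tensors are assumptions for illustration.

```python
import torch


def load_balancing_loss(counts: torch.Tensor, route_prob_sum: torch.Tensor) -> torch.Tensor:
    """N * sum_i f_i * P_i per layer, summed over layers (multiply by alpha outside).

    counts:         [n_layers, n_experts] tokens routed to each expert (before dropping)
    route_prob_sum: [n_layers, n_experts] sum of routing probabilities per expert
    """
    n_experts = counts.shape[-1]
    total = counts.sum(dim=-1, keepdim=True)  # tokens per layer
    route_frac = counts / total               # f_i
    prob_frac = route_prob_sum / total        # P_i
    return n_experts * (route_frac * prob_frac).sum()


# Two hypothetical layers of 16 tokens over 4 experts:
# the first is perfectly balanced, the second sends everything to expert 0.
counts = torch.tensor([[4., 4., 4., 4.], [16., 0., 0., 0.]])
route_prob_sum = torch.tensor([[4., 4., 4., 4.], [16., 0., 0., 0.]])
aux = load_balancing_loss(counts, route_prob_sum)
print(aux.item())  # 5.0 = 1.0 (balanced layer) + 4.0 (collapsed layer)
```

A perfectly balanced layer contributes 1 before scaling by $\alpha$, while a layer that routes all tokens and probability mass to one expert contributes $N$, so minimizing the loss pushes the router toward spreading tokens across experts.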