Support padding_idx in the lookup_table_op.
Created by: lcy-seso
The dot-product attention can be formulated as:
$$ \text{Attention}(Q, K, V) = \text{softmax}(QK^T)V $$
Above, $Q$, $K$, and $V$ are all 3-D tensors.
- $Q$ has a shape of $[\text{bs} \times N \times D]$, where $\text{bs}$ is the batch size, $N$ is the max target sequence length in a mini-batch (sequences in a mini-batch are all padded to the same length), and $D$ is the hidden size.
- $K$ has a shape of $[\text{bs} \times M \times D]$, where $M$ is the max source sequence length in a mini-batch.
- $V$ has a shape of $[\text{bs} \times M \times D]$, the same as $K$.
With the above notation in hand, we have:
- Suppose $W = \text{softmax}(QK^T)$; $W$ is the attention weight with a shape of $[\text{bs} \times N \times M]$.
- Suppose $C = WV$; $C$ is the context vector with a shape of $[\text{bs} \times N \times D]$.
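For concreteness, here is a minimal NumPy sketch of the batched computation above (all names and sizes are illustrative, not part of any existing op):

```python
import numpy as np

bs, N, M, D = 2, 5, 7, 16                     # batch size, target len, source len, hidden size
Q = np.random.randn(bs, N, D)
K = np.random.randn(bs, M, D)
V = np.random.randn(bs, M, D)

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

scores = np.matmul(Q, K.transpose(0, 2, 1))   # batched matmul, [bs, N, M]
W = softmax(scores, axis=-1)                  # attention weights, [bs, N, M]
C = np.matmul(W, V)                           # context vectors, [bs, N, D]
assert C.shape == (bs, N, D)
```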
From the above computation, to use batched matrix multiplication (which can potentially be optimized for better computational efficiency), all source sequences in a mini-batch must have the same length, and so must all target sequences (the source length and the target length can differ).
This requires padding sequences in one mini-batch to have the same length:
- The padding has to be fixed to zeros so that it does not affect the softmax normalization.
- The padding should not be updated during training; it does not need gradients.
PyTorch implements this by making `padding_idx` a special token for its embedding layer (the counterpart of the `lookup_table_op`): http://pytorch.org/docs/0.3.0/nn.html?highlight=embedding#torch.nn.Embedding .
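For reference, a minimal sketch of this behavior with `torch.nn.Embedding` (written against the modern `torch.tensor` interface rather than the 0.3.0 one linked above): rows looked up at `padding_idx` come out as zeros, and the padding row receives zero gradient.

```python
import torch
import torch.nn as nn

pad = 0                                              # index reserved for padding
emb = nn.Embedding(num_embeddings=10, embedding_dim=4, padding_idx=pad)

ids = torch.tensor([[3, 5, pad, pad]])               # one padded sequence
out = emb(ids)                                       # [1, 4, 4]
print(out[0, 2:])                                    # padded positions are all zeros

out.sum().backward()
print(emb.weight.grad[pad])                          # gradient of the padding row is zero
```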
Maybe it can also be implemented as a mask.
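A rough sketch of the mask idea (illustrative NumPy only; `padding_idx`, `table`, and `ids` are made-up names): multiply the looked-up embeddings by a 0/1 mask so the padded positions are zeros in the forward pass, and apply the same mask to the incoming gradient in the backward pass so the padding row is never updated.

```python
import numpy as np

padding_idx = 0
table = np.random.randn(10, 4)                       # embedding table, [vocab, D]

ids = np.array([[3, 5, padding_idx, padding_idx]])   # padded word ids
emb = table[ids]                                     # ordinary lookup, [bs, len, D]
mask = (ids != padding_idx)[..., None].astype(emb.dtype)
emb = emb * mask                                     # padded positions become zeros

# In the backward pass, multiplying the incoming gradient by the same mask
# keeps the row at padding_idx from ever being updated.
```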
On the other hand:
- If we do not pad the sequences in a mini-batch to the same length, the dot-product attention has to be computed in a `for` loop. I am not sure about the difference in computation speed between padding and no padding; a sketch of the loop version is given below.
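For comparison, the un-padded alternative would look roughly like this (a sketch only; `Qs`, `Ks`, and `Vs` are hypothetical per-sample lists of variable-length sequences):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)          # for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_no_padding(Qs, Ks, Vs):
    """Qs[i]: [N_i, D]; Ks[i], Vs[i]: [M_i, D]; lengths vary per sample."""
    contexts = []
    for q, k, v in zip(Qs, Ks, Vs):                  # one pair of small matmuls per sample
        w = softmax(q @ k.T, axis=-1)                # [N_i, M_i]
        contexts.append(w @ v)                       # [N_i, D]
    return contexts
```

Whether the per-sample loop is slower than padding plus batched matmul probably depends on the sequence-length distribution and on how well the underlying BLAS/GPU kernels handle many small multiplications.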