Commit 969df719 authored by Varuna Jayasiri

documentation fixes

Parent 30bbea41
@@ -204,7 +204,7 @@
 <div class='section-link'>
 <a href='#section-10'>#</a>
 </div>
-<p>Length of the memory (for masks)</p>
+<p>Total length of the memory and compressed memory (for masks)</p>
 </div>
 <div class='code'>
 <div class="highlight"><pre><span class="lineno">58</span> <span class="n">m_len</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">mem</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="k">if</span> <span class="n">mem</span> <span class="k">else</span> <span class="mi">0</span>
@@ -311,7 +311,7 @@
 <a href='#section-19'>#</a>
 </div>
 <h2>Configurations</h2>
-<p>The default configs can and will be over-ridden when we start the experiment.</p>
+<p>The default configurations can and will be overridden when we start the experiment.</p>
 </div>
 <div class='code'>
 <div class="highlight"><pre><span class="lineno">87</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="p">):</span></pre></div>
@@ -512,7 +512,7 @@
 <div class='section-link'>
 <a href='#section-37'>#</a>
 </div>
-<p>If it&rsquo;s configured not to use memory</p>
+<p>If the configurations specify not to use memory</p>
 </div>
 <div class='code'>
 <div class="highlight"><pre><span class="lineno">136</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mem_len</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">c_mem_len</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
@@ -726,7 +726,7 @@ and $N_m$ is the maximum number of memories we maintain (<code>mem_len</code>).
 <a href='#section-55'>#</a>
 </div>
 <p>Return memories and the memories that were compressed.
-Memories that were compressed is needed for the reconstruction loss computation.</p>
+Memories that were compressed are needed for the reconstruction loss computation.</p>
 </div>
 <div class='code'>
 <div class="highlight"><pre><span class="lineno">195</span> <span class="k">return</span> <span class="n">CompressedMemory</span><span class="p">(</span><span class="n">mem</span><span class="p">,</span> <span class="n">c_mem</span><span class="p">),</span> <span class="n">mem_to_compress</span></pre></div>
......
@@ -22,29 +22,30 @@ $n_{cm}$ memories, where $c$ is the compression rate.
 The compression operation is defined as
 $f_c: \mathbb{R}^{nc \times d} \rightarrow \mathbb{R}^{n \times d}$.
 The paper introduces multiple choices for $f_c$ and we have only implemented
-1D convolution which seems to give best results.
+1D convolution which seems to give the best results.
 Each layer has a separate compression operation $f_c^{(i)}$ where
 $i$ is the layer number.
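
For concreteness, here is a minimal sketch of such a 1D-convolution compression function, assuming PyTorch. The class name and tensor layout are illustrative rather than a verbatim copy of the repository's implementation. With kernel size and stride both equal to the compression rate, it maps $nc$ memories to $n$ compressed memories, matching $f_c: \mathbb{R}^{nc \times d} \rightarrow \mathbb{R}^{n \times d}$:

```python
import torch
from torch import nn


class Conv1dCompression(nn.Module):
    """Compress every `compression_rate` consecutive memories into one
    with a 1D convolution (kernel size = stride = compression rate)."""

    def __init__(self, compression_rate: int, d_model: int):
        super().__init__()
        self.conv = nn.Conv1d(d_model, d_model,
                              kernel_size=compression_rate,
                              stride=compression_rate)

    def forward(self, mem: torch.Tensor) -> torch.Tensor:
        # `mem` has shape `[seq_len, batch, d_model]`;
        # Conv1d expects `[batch, channels, seq_len]`
        mem = mem.permute(1, 2, 0)
        c_mem = self.conv(mem)
        # Back to `[compressed_len, batch, d_model]`
        return c_mem.permute(2, 0, 1)
```

Each layer $i$ would own its own instance of such a module, giving the per-layer $f_c^{(i)}$.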
 ## Training compression operation
 Since training compression with BPTT requires maintaining
-a very large computational graph (many time steps), paper proposes
+a very large computational graph (many time steps), the paper proposes
 an *auto-encoding loss* and an *attention reconstruction loss*.
-The auto-encoding loss, decodes the original memories from the compressed memories,
-and calculate the loss.
+The auto-encoding loss decodes the original memories from the compressed memories
+and calculates the loss.
 Attention reconstruction loss computes the multi-headed attention results
-on the compressed memory and on uncompressed memory and get a mean squared error
+on the compressed memory and on uncompressed memory and gets a mean squared error
 between them.
 We have implemented the latter here since it gives better results.
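
In code, the attention reconstruction loss amounts to comparing two attention outputs. A condensed sketch, where the helpers `attn` and `compress` are stand-ins for a layer's (frozen) multi-head attention and its compression function, introduced here purely for illustration:

```python
import torch.nn.functional as F


def attention_reconstruction_loss(h, mem, attn, compress):
    """MSE between attention over the original memories and attention
    over their compressed form. `attn(query, key, value)` is the layer's
    frozen multi-head attention; `compress` is that layer's f_c."""
    c_mem = compress(mem)
    attn_mem = attn(h, mem, mem)        # attention over uncompressed memory
    attn_c_mem = attn(h, c_mem, c_mem)  # attention over compressed memory
    return F.mse_loss(attn_c_mem, attn_mem)
```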
-This implementation uses pre-layer norm while the paper uses post-layer norm.
+This implementation uses pre-layer normalization
+while the paper uses post-layer normalization.
 Pre-layer norm does the layer norm before [FFN](../feedforward.html) and
-self attention, and the pass through in the residual connection is not normalized.
+self-attention, and the pass-through in the residual connection is not normalized.
 This is supposed to be more stable in standard transformer setups.
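
The difference between the two arrangements is easiest to see side by side. A schematic sketch (not the repository's actual code), where `norm` is a layer normalization and `sublayer` is either self-attention or the FFN:

```python
def pre_norm(x, norm, sublayer):
    # Pre-LN (this implementation): normalize the sub-layer input;
    # the residual pass-through `x` is left un-normalized.
    return x + sublayer(norm(x))


def post_norm(x, norm, sublayer):
    # Post-LN (the paper): normalize after the residual addition.
    return norm(x + sublayer(x))
```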
-Here's [the training code](experiment.html) and a notebook for training a compressive transformer
-model on Tiny Shakespeare dataset.
+Here are [the training code](experiment.html) and a notebook for training a compressive transformer
+model on the Tiny Shakespeare dataset.
 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/compressive/experiment.ipynb)
 [![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://web.lab-ml.com/run?uuid=0d9b5338726c11ebb7c80242ac1c0002)
@@ -219,18 +220,18 @@ class AttentionReconstructionLoss:
     """
     ## Attention Reconstruction Loss
-    Attention reconstruction loss recreates the self attention output with
-    uncompressed memory and with compressed memory and calculate mean squared error
+    Attention reconstruction loss recreates the self-attention output with
+    uncompressed memory and with compressed memory and calculates the mean squared error
     between the two. It does this without positional encoding.
     When calculating and training the compression function $f_c$ with attention
-    reconstruction loss all parameters but $f_c$ are frozen.
-    This includes key value projections and bias/scaling after normalization.
+    reconstruction loss, all parameters but $f_c$ are frozen.
+    This includes key/value projections and bias/scaling after normalization.
     Since this loss can be computed independently of the cross-entropy-loss of the model
     you can have a separate optimizer that only updates $f_c$.
     However, we use the same optimizer to update $f_c$ so when calculating
-    attention reconstruction loss we detach all other parameters except $f_c$
+    attention reconstruction loss, we detach all other parameters except $f_c$
     from the gradient computation.
     """

     def __init__(self, layers: TypedModuleList[CompressiveTransformerLayer]):
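
One way to keep everything except $f_c$ frozen while sharing a single optimizer is to apply each projection with detached weights, so the reconstruction loss contributes no gradient to them. A hypothetical helper sketching the idea (not necessarily the exact code in this class):

```python
import torch
import torch.nn.functional as F
from torch import nn


def frozen_linear(layer: nn.Linear, x: torch.Tensor) -> torch.Tensor:
    # Apply the projection with its parameters detached, so gradients
    # from the reconstruction loss flow only into f_c, not into the
    # attention weights themselves.
    bias = layer.bias.detach() if layer.bias is not None else None
    return F.linear(x, layer.weight.detach(), bias)
```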
@@ -320,7 +321,7 @@ class AttentionReconstructionLoss:
         mem = self.norm(layer.norm_self_attn, mem)
         c_mem = self.norm(layer.norm_self_attn, c_mem)

-        # Calculate attention with uncompressed memory
+        # Calculate the attention with uncompressed memory
         attn_mem = self.attn(layer.self_attn, h, mem, mem)
         # Calculate the attention with compressed memory
         attn_cmem = self.attn(layer.self_attn, h, c_mem, c_mem)
......
@@ -54,7 +54,7 @@ class AutoregressiveModel(Module):
             mem = []
             c_mem = []

-        # Length of the memory (for masks)
+        # Total length of the memory and compressed memory (for masks)
         m_len = len(mem[0]) if mem else 0
         if c_mem:
             m_len += len(c_mem[0])
@@ -88,7 +88,7 @@ class Configs(NLPAutoRegressionConfigs):
     """
     ## Configurations
-    The default configs can and will be over-ridden when we start the experiment.
+    The default configurations can and will be overridden when we start the experiment.
     """

     model: AutoregressiveModel
@@ -132,7 +132,7 @@ class Configs(NLPAutoRegressionConfigs):
         Concatenate new memories and compress the oldest memories.
         """

-        # If it's configured not to use memory
+        # If the configurations specify not to use memory
         if self.mem_len == 0 and self.c_mem_len == 0:
             return CompressedMemory([], []), []
@@ -191,7 +191,7 @@ class Configs(NLPAutoRegressionConfigs):
             mem_to_compress = []

         # Return memories and the memories that were compressed.
-        # Memories that were compressed is needed for the reconstruction loss computation.
+        # Memories that were compressed are needed for the reconstruction loss computation.
         return CompressedMemory(mem, c_mem), mem_to_compress

     def step(self, batch: any, batch_idx: BatchIndex):
......
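
The merge step diffed above can be summarized by the following per-layer sketch. The names are illustrative and the real implementation works on lists of per-layer tensors (and rounds the compressed span to a multiple of the compression rate), but the control flow is the same: append new memories, and once the uncompressed memory grows past `mem_len`, split off the oldest part and compress it with $f_c$:

```python
import torch


def merge_and_compress(mem, new_mem, mem_len: int, compress):
    """Sketch: keep the newest `mem_len` memories uncompressed and
    compress the oldest overflow. Returns the new memory, the newly
    compressed memory, and the memories that were compressed (needed
    for the reconstruction loss)."""
    mem = torch.cat((mem, new_mem), dim=0) if mem is not None else new_mem
    if mem.shape[0] > mem_len:
        # Oldest memories get compressed; the newest `mem_len` are kept
        mem_to_compress, mem = mem[:-mem_len], mem[-mem_len:]
        c_mem = compress(mem_to_compress)
    else:
        mem_to_compress, c_mem = None, None
    return mem, c_mem, mem_to_compress
```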